diff --git a/backend/apps/ai/Makefile b/backend/apps/ai/Makefile index cff4221abe..3243269378 100644 --- a/backend/apps/ai/Makefile +++ b/backend/apps/ai/Makefile @@ -1,23 +1,43 @@ -ai-create-chapter-chunks: - @echo "Creating chapter chunks" - @CMD="python manage.py ai_create_chapter_chunks" $(MAKE) exec-backend-command +ai-run-rag-tool: + @echo "Running RAG tool" + @CMD="python manage.py ai_run_rag_tool" $(MAKE) exec-backend-command -ai-create-committee-chunks: - @echo "Creating committee chunks" - @CMD="python manage.py ai_create_committee_chunks" $(MAKE) exec-backend-command +ai-update-chapter-chunks: + @echo "Updating chapter chunks" + @CMD="python manage.py ai_update_chapter_chunks" $(MAKE) exec-backend-command -ai-create-event-chunks: - @echo "Creating event chunks" - @CMD="python manage.py ai_create_event_chunks" $(MAKE) exec-backend-command +ai-update-chapter-context: + @echo "Updating chapter context" + @CMD="python manage.py ai_update_chapter_context" $(MAKE) exec-backend-command -ai-create-project-chunks: - @echo "Creating project chunks" - @CMD="python manage.py ai_create_project_chunks" $(MAKE) exec-backend-command +ai-update-committee-chunks: + @echo "Updating committee chunks" + @CMD="python manage.py ai_update_committee_chunks" $(MAKE) exec-backend-command -ai-create-slack-message-chunks: - @echo "Creating Slack message chunks" - @CMD="python manage.py ai_create_slack_message_chunks" $(MAKE) exec-backend-command +ai-update-committee-context: + @echo "Updating committee context" + @CMD="python manage.py ai_update_committee_context" $(MAKE) exec-backend-command -ai-run-rag-tool: - @echo "Running RAG tool" - @CMD="python manage.py ai_run_rag_tool" $(MAKE) exec-backend-command +ai-update-event-chunks: + @echo "Updating event chunks" + @CMD="python manage.py ai_update_event_chunks" $(MAKE) exec-backend-command + +ai-update-event-context: + @echo "Updating event context" + @CMD="python manage.py ai_update_event_context" $(MAKE) exec-backend-command + 
+ai-update-project-chunks: + @echo "Updating project chunks" + @CMD="python manage.py ai_update_project_chunks" $(MAKE) exec-backend-command + +ai-update-project-context: + @echo "Updating project context" + @CMD="python manage.py ai_update_project_context" $(MAKE) exec-backend-command + +ai-update-slack-message-chunks: + @echo "Updating Slack message chunks" + @CMD="python manage.py ai_update_slack_message_chunks" $(MAKE) exec-backend-command + +ai-update-slack-message-context: + @echo "Updating Slack message context" + @CMD="python manage.py ai_update_slack_message_context" $(MAKE) exec-backend-command diff --git a/backend/apps/ai/admin.py b/backend/apps/ai/admin.py index a7240e8115..d0852aeb48 100644 --- a/backend/apps/ai/admin.py +++ b/backend/apps/ai/admin.py @@ -3,6 +3,7 @@ from django.contrib import admin from apps.ai.models.chunk import Chunk +from apps.ai.models.context import Context class ChunkAdmin(admin.ModelAdmin): @@ -11,9 +12,25 @@ class ChunkAdmin(admin.ModelAdmin): list_display = ( "id", "text", - "content_type", + "context", ) - search_fields = ("text", "object_id") + list_filter = ("context__entity_type",) + search_fields = ("text",) + + +class ContextAdmin(admin.ModelAdmin): + """Admin for Context model.""" + + list_display = ( + "id", + "content", + "entity_type", + "entity_id", + "source", + ) + list_filter = ("entity_type", "source") + search_fields = ("content", "source") admin.site.register(Chunk, ChunkAdmin) +admin.site.register(Context, ContextAdmin) diff --git a/backend/apps/ai/agent/tools/rag/rag_tool.py b/backend/apps/ai/agent/tools/rag/rag_tool.py index 6072de8cbb..8375b4a328 100644 --- a/backend/apps/ai/agent/tools/rag/rag_tool.py +++ b/backend/apps/ai/agent/tools/rag/rag_tool.py @@ -28,12 +28,8 @@ def __init__( ValueError: If the OpenAI API key is not set. 
""" - try: - self.retriever = Retriever(embedding_model=embedding_model) - self.generator = Generator(chat_model=chat_model) - except Exception: - logger.exception("Failed to initialize RAG tool") - raise + self.retriever = Retriever(embedding_model=embedding_model) + self.generator = Generator(chat_model=chat_model) def query( self, diff --git a/backend/apps/ai/agent/tools/rag/retriever.py b/backend/apps/ai/agent/tools/rag/retriever.py index a4ed638ef6..0a7619796c 100644 --- a/backend/apps/ai/agent/tools/rag/retriever.py +++ b/backend/apps/ai/agent/tools/rag/retriever.py @@ -21,7 +21,13 @@ class Retriever: """A class for retrieving relevant text chunks for a RAG.""" - SUPPORTED_CONTENT_TYPES = ["event", "project", "chapter", "committee", "message"] + SUPPORTED_ENTITY_TYPES = ( + "chapter", + "committee", + "event", + "message", + "project", + ) def __init__(self, embedding_model: str = "text-embedding-3-small"): """Initialize the Retriever. @@ -36,7 +42,6 @@ def __init__(self, embedding_model: str = "text-embedding-3-small"): if not (openai_api_key := os.getenv("DJANGO_OPEN_AI_SECRET_KEY")): error_msg = "DJANGO_OPEN_AI_SECRET_KEY environment variable not set" raise ValueError(error_msg) - self.openai_client = openai.OpenAI(api_key=openai_api_key) self.embedding_model = embedding_model logger.info("Retriever initialized with embedding model: %s", self.embedding_model) @@ -64,121 +69,116 @@ def get_query_embedding(self, query: str) -> list[float]: logger.exception("Unexpected error while generating embedding") raise - def get_source_name(self, content_object) -> str: + def get_source_name(self, entity) -> str: """Get the name/identifier for the content object.""" for attr in ("name", "title", "login", "key", "summary"): - if getattr(content_object, attr, None): - return str(getattr(content_object, attr)) - - return str(content_object) + if getattr(entity, attr, None): + return str(getattr(entity, attr)) + return str(entity) - def get_additional_context(self, 
content_object, content_type: str) -> dict[str, Any]: + def get_additional_context(self, entity) -> dict[str, Any]: """Get additional context information based on content type. Args: - content_object: The source object. - content_type: The model name of the content object. + entity: The source object. Returns: A dictionary with additional context information. """ context = {} - clean_content_type = content_type.split(".")[-1] if "." in content_type else content_type - + clean_content_type = entity.__class__.__name__.lower() if clean_content_type == "chapter": context.update( { - "location": getattr(content_object, "suggested_location", None), - "region": getattr(content_object, "region", None), - "country": getattr(content_object, "country", None), - "postal_code": getattr(content_object, "postal_code", None), - "currency": getattr(content_object, "currency", None), - "meetup_group": getattr(content_object, "meetup_group", None), - "tags": getattr(content_object, "tags", []), - "topics": getattr(content_object, "topics", []), - "leaders": getattr(content_object, "leaders_raw", []), - "related_urls": getattr(content_object, "related_urls", []), - "is_active": getattr(content_object, "is_active", None), - "url": getattr(content_object, "url", None), + "location": getattr(entity, "suggested_location", None), + "region": getattr(entity, "region", None), + "country": getattr(entity, "country", None), + "postal_code": getattr(entity, "postal_code", None), + "currency": getattr(entity, "currency", None), + "meetup_group": getattr(entity, "meetup_group", None), + "tags": getattr(entity, "tags", []), + "topics": getattr(entity, "topics", []), + "leaders": getattr(entity, "leaders_raw", []), + "related_urls": getattr(entity, "related_urls", []), + "is_active": getattr(entity, "is_active", None), + "url": getattr(entity, "url", None), } ) elif clean_content_type == "project": context.update( { - "level": getattr(content_object, "level", None), - "project_type": 
getattr(content_object, "type", None), - "languages": getattr(content_object, "languages", []), - "topics": getattr(content_object, "topics", []), - "licenses": getattr(content_object, "licenses", []), - "tags": getattr(content_object, "tags", []), - "custom_tags": getattr(content_object, "custom_tags", []), - "stars_count": getattr(content_object, "stars_count", None), - "forks_count": getattr(content_object, "forks_count", None), - "contributors_count": getattr(content_object, "contributors_count", None), - "releases_count": getattr(content_object, "releases_count", None), - "open_issues_count": getattr(content_object, "open_issues_count", None), - "leaders": getattr(content_object, "leaders_raw", []), - "related_urls": getattr(content_object, "related_urls", []), - "created_at": getattr(content_object, "created_at", None), - "updated_at": getattr(content_object, "updated_at", None), - "released_at": getattr(content_object, "released_at", None), - "health_score": getattr(content_object, "health_score", None), - "is_active": getattr(content_object, "is_active", None), - "track_issues": getattr(content_object, "track_issues", None), - "url": getattr(content_object, "url", None), + "level": getattr(entity, "level", None), + "project_type": getattr(entity, "type", None), + "languages": getattr(entity, "languages", []), + "topics": getattr(entity, "topics", []), + "licenses": getattr(entity, "licenses", []), + "tags": getattr(entity, "tags", []), + "custom_tags": getattr(entity, "custom_tags", []), + "stars_count": getattr(entity, "stars_count", None), + "forks_count": getattr(entity, "forks_count", None), + "contributors_count": getattr(entity, "contributors_count", None), + "releases_count": getattr(entity, "releases_count", None), + "open_issues_count": getattr(entity, "open_issues_count", None), + "leaders": getattr(entity, "leaders_raw", []), + "related_urls": getattr(entity, "related_urls", []), + "created_at": getattr(entity, "created_at", None), + 
"updated_at": getattr(entity, "updated_at", None), + "released_at": getattr(entity, "released_at", None), + "health_score": getattr(entity, "health_score", None), + "is_active": getattr(entity, "is_active", None), + "track_issues": getattr(entity, "track_issues", None), + "url": getattr(entity, "url", None), } ) elif clean_content_type == "event": context.update( { - "start_date": getattr(content_object, "start_date", None), - "end_date": getattr(content_object, "end_date", None), - "location": getattr(content_object, "suggested_location", None), - "category": getattr(content_object, "category", None), - "latitude": getattr(content_object, "latitude", None), - "longitude": getattr(content_object, "longitude", None), - "url": getattr(content_object, "url", None), - "description": getattr(content_object, "description", None), - "summary": getattr(content_object, "summary", None), + "start_date": getattr(entity, "start_date", None), + "end_date": getattr(entity, "end_date", None), + "location": getattr(entity, "suggested_location", None), + "category": getattr(entity, "category", None), + "latitude": getattr(entity, "latitude", None), + "longitude": getattr(entity, "longitude", None), + "url": getattr(entity, "url", None), + "description": getattr(entity, "description", None), + "summary": getattr(entity, "summary", None), } ) elif clean_content_type == "committee": context.update( { - "is_active": getattr(content_object, "is_active", None), - "leaders": getattr(content_object, "leaders", []), - "url": getattr(content_object, "url", None), - "description": getattr(content_object, "description", None), - "summary": getattr(content_object, "summary", None), - "tags": getattr(content_object, "tags", []), - "topics": getattr(content_object, "topics", []), - "related_urls": getattr(content_object, "related_urls", []), + "is_active": getattr(entity, "is_active", None), + "leaders": getattr(entity, "leaders", []), + "url": getattr(entity, "url", None), + "description": 
getattr(entity, "description", None), + "summary": getattr(entity, "summary", None), + "tags": getattr(entity, "tags", []), + "topics": getattr(entity, "topics", []), + "related_urls": getattr(entity, "related_urls", []), } ) elif clean_content_type == "message": context.update( { "channel": ( - getattr(content_object.conversation, "slack_channel_id", None) - if hasattr(content_object, "conversation") and content_object.conversation + getattr(entity.conversation, "slack_channel_id", None) + if hasattr(entity, "conversation") and entity.conversation else None ), "thread_ts": ( - getattr(content_object.parent_message, "ts", None) - if hasattr(content_object, "parent_message") - and content_object.parent_message + getattr(entity.parent_message, "ts", None) + if hasattr(entity, "parent_message") and entity.parent_message else None ), - "ts": getattr(content_object, "ts", None), + "ts": getattr(entity, "ts", None), "user": ( - getattr(content_object.author, "name", None) - if hasattr(content_object, "author") and content_object.author + getattr(entity.author, "name", None) + if hasattr(entity, "author") and entity.author else None ), } ) - return {k: v for k, v in context.items() if v is not None} def retrieve( @@ -201,14 +201,11 @@ def retrieve( """ query_embedding = self.get_query_embedding(query) - if not content_types: content_types = self.extract_content_types_from_query(query) - queryset = Chunk.objects.annotate( similarity=1 - CosineDistance("embedding", query_embedding) ).filter(similarity__gte=similarity_threshold) - if content_types: content_type_query = Q() for name in content_types: @@ -216,36 +213,31 @@ def retrieve( if "." 
in lower_name: app_label, model = lower_name.split(".", 1) content_type_query |= Q( - content_type__app_label=app_label, content_type__model=model + context__entity_type__app_label=app_label, + context__entity_type__model=model, ) else: - content_type_query |= Q(content_type__model=lower_name) + content_type_query |= Q(context__entity_type__model=lower_name) queryset = queryset.filter(content_type_query) - chunks = ( - queryset.select_related("content_type") - .prefetch_related("content_object") - .order_by("-similarity")[:limit] - ) + chunks = queryset.select_related("context__entity_type").order_by("-similarity")[:limit] results = [] for chunk in chunks: - if not chunk.content_object: + if not chunk.context or not chunk.context.entity: logger.warning("Content object is None for chunk %s. Skipping.", chunk.id) continue - source_name = self.get_source_name(chunk.content_object) - additional_context = self.get_additional_context( - chunk.content_object, chunk.content_type.model - ) + source_name = self.get_source_name(chunk.context.entity) + additional_context = self.get_additional_context(chunk.context.entity) results.append( { "text": chunk.text, "similarity": float(chunk.similarity), - "source_type": chunk.content_type.model, + "source_type": chunk.context.entity_type.model, "source_name": source_name, - "source_id": chunk.object_id, + "source_id": chunk.context.entity_id, "additional_context": additional_context, } ) @@ -262,13 +254,12 @@ def extract_content_types_from_query(self, query: str) -> list[str]: A list of detected content type names. 
""" - detected_types = [] query_words = set(re.findall(r"\b\w+\b", query.lower())) detected_types = [ - content_type - for content_type in self.SUPPORTED_CONTENT_TYPES - if content_type in query_words or f"{content_type}s" in query_words + entity_type + for entity_type in self.SUPPORTED_ENTITY_TYPES + if entity_type in query_words or f"{entity_type}s" in query_words ] if detected_types: diff --git a/backend/apps/ai/common/base/__init__.py b/backend/apps/ai/common/base/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/backend/apps/ai/common/base/ai_command.py b/backend/apps/ai/common/base/ai_command.py new file mode 100644 index 0000000000..30337603e6 --- /dev/null +++ b/backend/apps/ai/common/base/ai_command.py @@ -0,0 +1,109 @@ +"""Base AI command class with common functionality.""" + +import os +from collections.abc import Callable +from typing import Any + +import openai +from django.core.management.base import BaseCommand +from django.db.models import Model, QuerySet + + +class BaseAICommand(BaseCommand): + """Base class for AI management commands with common functionality.""" + + model_class: type[Model] + key_field_name: str + + def __init__(self, *args, **kwargs): + """Initialize the AI command with OpenAI client placeholder.""" + super().__init__(*args, **kwargs) + self.openai_client: openai.OpenAI | None = None + self.entity_name = self.model_class.__name__.lower() + self.entity_name_plural = self.model_class.__name__.lower() + "s" + + def source_name(self) -> str: + """Return the source name for context creation. Override if different from default.""" + return f"owasp_{self.entity_name}" + + def get_base_queryset(self) -> QuerySet: + """Return the base queryset. 
Override for custom filtering logic.""" + return self.model_class.objects.all() + + def get_default_queryset(self) -> QuerySet: + """Return the default queryset when no specific options are provided.""" + return self.get_base_queryset().filter(is_active=True) + + def add_common_arguments(self, parser): + """Add common arguments that most commands need.""" + parser.add_argument( + f"--{self.entity_name}-key", + type=str, + help=f"Process only the {self.entity_name} with this key", + ) + parser.add_argument( + "--all", + action="store_true", + help=f"Process all the {self.entity_name_plural}", + ) + parser.add_argument( + "--batch-size", + type=int, + default=50, + help=f"Number of {self.entity_name_plural} to process in each batch", + ) + + def add_arguments(self, parser): + """Add arguments to the command. Override to add custom arguments.""" + self.add_common_arguments(parser) + + def get_queryset(self, options: dict[str, Any]) -> QuerySet: + """Get the queryset based on command options.""" + key_option = f"{self.entity_name}_key" + + if options.get(key_option): + filter_kwargs = {self.key_field_name: options[key_option]} + return self.get_base_queryset().filter(**filter_kwargs) + if options.get("all"): + return self.get_base_queryset() + return self.get_default_queryset() + + def get_entity_key(self, entity: Model) -> str: + """Get the key/identifier for an entity for display purposes.""" + return str(getattr(entity, self.key_field_name, entity.pk)) + + def setup_openai_client(self) -> bool: + """Set up OpenAI client if API key is available.""" + if openai_api_key := os.getenv("DJANGO_OPEN_AI_SECRET_KEY"): + self.openai_client = openai.OpenAI(api_key=openai_api_key) + return True + self.stdout.write( + self.style.ERROR("DJANGO_OPEN_AI_SECRET_KEY environment variable not set") + ) + return False + + def handle_batch_processing( + self, + queryset: QuerySet, + batch_size: int, + process_batch_func: Callable[[list[Model]], int], + ) -> None: + """Handle the common 
batch processing logic.""" + total_count = queryset.count() + + if not total_count: + self.stdout.write(f"No {self.entity_name_plural} found to process") + return + + self.stdout.write(f"Found {total_count} {self.entity_name_plural} to process") + + processed_count = 0 + for offset in range(0, total_count, batch_size): + batch_items = queryset[offset : offset + batch_size] + processed_count += process_batch_func(list(batch_items)) + + self.stdout.write( + self.style.SUCCESS( + f"Completed processing {processed_count}/{total_count} {self.entity_name_plural}" + ) + ) diff --git a/backend/apps/ai/common/base/chunk_command.py b/backend/apps/ai/common/base/chunk_command.py new file mode 100644 index 0000000000..6a29005d12 --- /dev/null +++ b/backend/apps/ai/common/base/chunk_command.py @@ -0,0 +1,91 @@ +"""Base chunk command class for creating chunks.""" + +from django.contrib.contenttypes.models import ContentType +from django.db.models import Max, Model + +from apps.ai.common.base.ai_command import BaseAICommand +from apps.ai.common.utils import create_chunks_and_embeddings +from apps.ai.models.chunk import Chunk +from apps.ai.models.context import Context + + +class BaseChunkCommand(BaseAICommand): + """Base class for chunk creation commands.""" + + def help(self) -> str: + """Return help text for the chunk creation command.""" + return f"Create or update chunks for OWASP {self.entity_name} data" + + def process_chunks_batch(self, entities: list[Model]) -> int: + """Process a batch of entities to create or update chunks.""" + processed = 0 + batch_chunks_to_create = [] + content_type = ContentType.objects.get_for_model(self.model_class) + + for entity in entities: + entity_key = self.get_entity_key(entity) + context = Context.objects.filter(entity_type=content_type, entity_id=entity.id).first() + + if not context: + self.stdout.write( + self.style.WARNING(f"No context found for {self.entity_name} {entity_key}") + ) + continue + + latest_chunk_timestamp = 
context.chunks.aggregate( + latest_created=Max("nest_created_at") + )["latest_created"] + + if not latest_chunk_timestamp or context.nest_updated_at > latest_chunk_timestamp: + self.stdout.write(f"Context for {entity_key} requires chunk creation/update") + + if latest_chunk_timestamp: + count, _ = context.chunks.all().delete() + self.stdout.write(f"Deleted {count} stale chunks for {entity_key}") + + prose_content, metadata_content = self.extract_content(entity) + full_content = ( + f"{metadata_content}\n\n{prose_content}" if metadata_content else prose_content + ) + + if not full_content.strip(): + self.stdout.write(f"No content to chunk for {self.entity_name} {entity_key}") + continue + + chunk_texts = Chunk.split_text(full_content) + if not chunk_texts: + self.stdout.write(f"No chunks created for {self.entity_name} {entity_key}") + continue + + if chunks := create_chunks_and_embeddings( + chunk_texts=chunk_texts, + context=context, + openai_client=self.openai_client, + save=False, + ): + batch_chunks_to_create.extend(chunks) + processed += 1 + self.stdout.write( + self.style.SUCCESS(f"Created {len(chunks)} new chunks for {entity_key}") + ) + else: + self.stdout.write(f"Chunks for {entity_key} are already up to date.") + + if batch_chunks_to_create: + Chunk.bulk_save(batch_chunks_to_create) + + return processed + + def handle(self, *args, **options): + """Handle the chunk creation command.""" + if not self.setup_openai_client(): + return + + queryset = self.get_queryset(options) + batch_size = options["batch_size"] + + self.handle_batch_processing( + queryset=queryset, + batch_size=batch_size, + process_batch_func=self.process_chunks_batch, + ) diff --git a/backend/apps/ai/common/base/context_command.py b/backend/apps/ai/common/base/context_command.py new file mode 100644 index 0000000000..5a43370450 --- /dev/null +++ b/backend/apps/ai/common/base/context_command.py @@ -0,0 +1,54 @@ +"""Base context command class for creating contexts.""" + +from django.db.models 
import Model + +from apps.ai.common.base.ai_command import BaseAICommand +from apps.ai.models.context import Context + + +class BaseContextCommand(BaseAICommand): + """Base class for context creation commands.""" + + def help(self) -> str: + """Return help text for the context creation command.""" + return f"Update context for OWASP {self.entity_name} data" + + def process_context_batch(self, entities: list[Model]) -> int: + """Process a batch of entities to create contexts.""" + processed = 0 + + for entity in entities: + prose_content, metadata_content = self.extract_content(entity) + full_content = ( + f"{metadata_content}\n\n{prose_content}" if metadata_content else prose_content + ) + + if not full_content.strip(): + entity_key = self.get_entity_key(entity) + self.stdout.write(f"No content for {self.entity_name} {entity_key}") + continue + + if Context.update_data( + content=full_content, + entity=entity, + source=self.source_name(), + ): + processed += 1 + entity_key = self.get_entity_key(entity) + self.stdout.write(f"Created context for {entity_key}") + else: + entity_key = self.get_entity_key(entity) + self.stdout.write(self.style.ERROR(f"Failed to create context for {entity_key}")) + + return processed + + def handle(self, *args, **options): + """Handle the context creation command.""" + queryset = self.get_queryset(options) + batch_size = options["batch_size"] + + self.handle_batch_processing( + queryset=queryset, + batch_size=batch_size, + process_batch_func=self.process_context_batch, + ) diff --git a/backend/apps/ai/common/constants.py b/backend/apps/ai/common/constants.py index 207b53599c..cce67fc739 100644 --- a/backend/apps/ai/common/constants.py +++ b/backend/apps/ai/common/constants.py @@ -2,6 +2,6 @@ DEFAULT_LAST_REQUEST_OFFSET_SECONDS = 2 DEFAULT_CHUNKS_RETRIEVAL_LIMIT = 5 -DEFAULT_SIMILARITY_THRESHOLD = 0.5 +DEFAULT_SIMILARITY_THRESHOLD = 0.4 DELIMITER = "\n\n" MIN_REQUEST_INTERVAL_SECONDS = 1.2 diff --git 
a/backend/apps/ai/common/extractors/__init__.py b/backend/apps/ai/common/extractors/__init__.py new file mode 100644 index 0000000000..8a8422e9cc --- /dev/null +++ b/backend/apps/ai/common/extractors/__init__.py @@ -0,0 +1 @@ +"""Common extractors for AI content processing.""" diff --git a/backend/apps/ai/common/extractors/chapter.py b/backend/apps/ai/common/extractors/chapter.py new file mode 100644 index 0000000000..0bacfcff2e --- /dev/null +++ b/backend/apps/ai/common/extractors/chapter.py @@ -0,0 +1,76 @@ +"""Content extractor for Chapter.""" + +from apps.ai.common.constants import DELIMITER + + +def extract_chapter_content(chapter) -> tuple[str, str]: + """Extract structured content from chapter data. + + Args: + chapter: Chapter instance + + Returns: + tuple[str, str]: (prose_content, metadata_content) + + """ + prose_parts = [] + metadata_parts = [] + + if chapter.description: + prose_parts.append(f"Description: {chapter.description}") + + if chapter.summary: + prose_parts.append(f"Summary: {chapter.summary}") + + if chapter.owasp_repository: + repo = chapter.owasp_repository + if repo.description: + prose_parts.append(f"Repository Description: {repo.description}") + if repo.topics: + metadata_parts.append(f"Repository Topics: {', '.join(repo.topics)}") + + if chapter.name: + metadata_parts.append(f"Chapter Name: {chapter.name}") + + location_parts = [] + if chapter.country: + location_parts.append(f"Country: {chapter.country}") + if chapter.region: + location_parts.append(f"Region: {chapter.region}") + if chapter.postal_code: + location_parts.append(f"Postal Code: {chapter.postal_code}") + if chapter.suggested_location: + location_parts.append(f"Location: {chapter.suggested_location}") + + if location_parts: + metadata_parts.append(f"Location Information: {', '.join(location_parts)}") + + if chapter.currency: + metadata_parts.append(f"Currency: {chapter.currency}") + + if chapter.meetup_group: + metadata_parts.append(f"Meetup Group: {chapter.meetup_group}") 
+"""Content extractor for Committee."""
+ + Args: + committee: Committee instance + + Returns: + tuple[str, str]: (prose_content, metadata_content) + + """ + prose_parts = [] + metadata_parts = [] + + if committee.description: + prose_parts.append(f"Description: {committee.description}") + + if committee.summary: + prose_parts.append(f"Summary: {committee.summary}") + + if committee.owasp_repository: + repo = committee.owasp_repository + if repo.description: + prose_parts.append(f"Repository Description: {repo.description}") + if repo.topics: + metadata_parts.append(f"Repository Topics: {', '.join(repo.topics)}") + + if committee.name: + metadata_parts.append(f"Committee Name: {committee.name}") + + if committee.tags: + metadata_parts.append(f"Tags: {', '.join(committee.tags)}") + + if committee.topics: + metadata_parts.append(f"Topics: {', '.join(committee.topics)}") + + if committee.leaders_raw: + metadata_parts.append(f"Committee Leaders: {', '.join(committee.leaders_raw)}") + + if committee.related_urls: + invalid_urls = getattr(committee, "invalid_urls", []) or [] + valid_urls = [url for url in committee.related_urls if url and url not in invalid_urls] + if valid_urls: + metadata_parts.append(f"Related URLs: {', '.join(valid_urls)}") + + metadata_parts.append(f"Active Committee: {'Yes' if committee.is_active else 'No'}") + + return ( + DELIMITER.join(filter(None, prose_parts)), + DELIMITER.join(filter(None, metadata_parts)), + ) diff --git a/backend/apps/ai/common/extractors/event.py b/backend/apps/ai/common/extractors/event.py new file mode 100644 index 0000000000..851a10439e --- /dev/null +++ b/backend/apps/ai/common/extractors/event.py @@ -0,0 +1,49 @@ +"""Content extractor for Event.""" + +from apps.ai.common.constants import DELIMITER + + +def extract_event_content(event) -> tuple[str, str]: + """Extract structured content from event data. 
+ + Args: + event: Event instance + + Returns: + tuple[str, str]: (prose_content, metadata_content) + + """ + prose_parts = [] + metadata_parts = [] + + if event.description: + prose_parts.append(f"Description: {event.description}") + + if event.summary: + prose_parts.append(f"Summary: {event.summary}") + + if event.name: + metadata_parts.append(f"Event Name: {event.name}") + + if event.category: + metadata_parts.append(f"Category: {event.get_category_display()}") + + if event.start_date: + metadata_parts.append(f"Start Date: {event.start_date}") + + if event.end_date: + metadata_parts.append(f"End Date: {event.end_date}") + + if event.suggested_location: + metadata_parts.append(f"Location: {event.suggested_location}") + + if event.latitude is not None and event.longitude is not None: + metadata_parts.append(f"Coordinates: {event.latitude}, {event.longitude}") + + if event.url: + metadata_parts.append(f"Event URL: {event.url}") + + return ( + DELIMITER.join(filter(None, prose_parts)), + DELIMITER.join(filter(None, metadata_parts)), + ) diff --git a/backend/apps/ai/common/extractors/project.py b/backend/apps/ai/common/extractors/project.py new file mode 100644 index 0000000000..c6b8ee6209 --- /dev/null +++ b/backend/apps/ai/common/extractors/project.py @@ -0,0 +1,97 @@ +"""Content extractor for Project.""" + +from apps.ai.common.constants import DELIMITER + + +def extract_project_content(project) -> tuple[str, str]: + """Extract structured content from project data. 
+ + Args: + project: Project instance + + Returns: + tuple[str, str]: (prose_content, metadata_content) + + """ + prose_parts = [] + metadata_parts = [] + + if project.description: + prose_parts.append(f"Description: {project.description}") + + if project.summary: + prose_parts.append(f"Summary: {project.summary}") + + if project.owasp_repository: + repo = project.owasp_repository + if repo.description: + prose_parts.append(f"Repository Description: {repo.description}") + if repo.topics: + metadata_parts.append(f"Repository Topics: {', '.join(repo.topics)}") + + if project.name: + metadata_parts.append(f"Project Name: {project.name}") + + if project.level: + metadata_parts.append(f"Project Level: {project.level}") + + if project.type: + metadata_parts.append(f"Project Type: {project.type}") + + if project.languages: + metadata_parts.append(f"Programming Languages: {', '.join(project.languages)}") + + if project.topics: + metadata_parts.append(f"Topics: {', '.join(project.topics)}") + + if project.licenses: + metadata_parts.append(f"Licenses: {', '.join(project.licenses)}") + + if project.tags: + metadata_parts.append(f"Tags: {', '.join(project.tags)}") + + if project.custom_tags: + metadata_parts.append(f"Custom Tags: {', '.join(project.custom_tags)}") + + stats_parts = [] + if project.stars_count: + stats_parts.append(f"Stars: {project.stars_count}") + if project.forks_count: + stats_parts.append(f"Forks: {project.forks_count}") + if project.contributors_count: + stats_parts.append(f"Contributors: {project.contributors_count}") + if project.releases_count: + stats_parts.append(f"Releases: {project.releases_count}") + if project.open_issues_count: + stats_parts.append(f"Open Issues: {project.open_issues_count}") + + if stats_parts: + metadata_parts.append("Project Statistics: " + ", ".join(stats_parts)) + + if project.leaders_raw: + metadata_parts.append(f"Project Leaders: {', '.join(project.leaders_raw)}") + + if project.related_urls: + invalid_urls = 
getattr(project, "invalid_urls", []) or [] + valid_urls = [url for url in project.related_urls if url and url not in invalid_urls] + if valid_urls: + metadata_parts.append(f"Related URLs: {', '.join(valid_urls)}") + + if project.created_at: + metadata_parts.append(f"Created: {project.created_at.strftime('%Y-%m-%d')}") + + if project.updated_at: + metadata_parts.append(f"Last Updated: {project.updated_at.strftime('%Y-%m-%d')}") + + if project.released_at: + metadata_parts.append(f"Last Release: {project.released_at.strftime('%Y-%m-%d')}") + + if project.health_score is not None: + metadata_parts.append(f"Health Score: {project.health_score:.2f}") + + metadata_parts.append(f"Active Project: {'Yes' if project.is_active else 'No'}") + + return ( + DELIMITER.join(filter(None, prose_parts)), + DELIMITER.join(filter(None, metadata_parts)), + ) diff --git a/backend/apps/ai/common/utils.py b/backend/apps/ai/common/utils.py index c0824760a1..8a258f7f3b 100644 --- a/backend/apps/ai/common/utils.py +++ b/backend/apps/ai/common/utils.py @@ -4,31 +4,44 @@ import time from datetime import UTC, datetime, timedelta +from openai import OpenAI, OpenAIError + from apps.ai.common.constants import ( DEFAULT_LAST_REQUEST_OFFSET_SECONDS, MIN_REQUEST_INTERVAL_SECONDS, ) from apps.ai.models.chunk import Chunk +from apps.ai.models.context import Context logger: logging.Logger = logging.getLogger(__name__) def create_chunks_and_embeddings( - all_chunk_texts: list[str], - content_object, + chunk_texts: list[str], + context: Context, openai_client, + model: str = "text-embedding-3-small", + *, + save: bool = True, ) -> list[Chunk]: """Create chunks and embeddings from given texts using OpenAI embeddings. Args: - all_chunk_texts (list[str]): List of text chunks to embed. - content_object: The object to associate the chunks with. - openai_client: Initialized OpenAI client instance. 
+ chunk_texts (list[str]): List of text chunks to process + context (Context): The context these chunks belong to + openai_client: Initialized OpenAI client + model (str): Embedding model to use + save (bool): Whether to save chunks immediately Returns: - list[Chunk]: List of Chunk instances (not saved). + list[Chunk]: List of created Chunk instances (empty if failed) + + Note: + OpenAI API errors are logged and an empty list is returned. """ + from apps.ai.models.chunk import Chunk + try: last_request_time = datetime.now(UTC) - timedelta( seconds=DEFAULT_LAST_REQUEST_OFFSET_SECONDS @@ -39,26 +52,47 @@ def create_chunks_and_embeddings( time.sleep(MIN_REQUEST_INTERVAL_SECONDS - time_since_last_request.total_seconds()) response = openai_client.embeddings.create( - input=all_chunk_texts, - model="text-embedding-3-small", + input=chunk_texts, + model=model, ) + embeddings = [d.embedding for d in response.data] + + chunks = [] + for text, embedding in zip(chunk_texts, embeddings, strict=True): + chunk = Chunk.update_data(text=text, embedding=embedding, context=context, save=save) + if chunk is not None: + chunks.append(chunk) - return [ - chunk - for text, embedding in zip( - all_chunk_texts, - [d.embedding for d in response.data], - strict=True, - ) - if ( - chunk := Chunk.update_data( - text=text, - content_object=content_object, - embedding=embedding, - save=False, - ) - ) - ] - except Exception: - logger.exception("OpenAI API error") + except (OpenAIError, AttributeError, TypeError): + logger.exception("Failed to create chunks and embeddings") return [] + else: + return chunks + + +def regenerate_chunks_for_context(context: Context): + """Regenerates all chunks for a single, specific context instance. + + Args: + context (Context): The specific context instance to be updated.
+ + """ + from apps.ai.models.chunk import Chunk + + context.chunks.all().delete() + new_chunk_texts = Chunk.split_text(context.content) + + if not new_chunk_texts: + logger.warning("No content to chunk for Context. Process stopped.") + return + + openai_client = OpenAI() + + create_chunks_and_embeddings( + chunk_texts=new_chunk_texts, + context=context, + openai_client=openai_client, + save=True, + ) + + logger.info("Successfully completed chunk regeneration for new context") diff --git a/backend/apps/ai/management/commands/ai_create_chapter_chunks.py b/backend/apps/ai/management/commands/ai_create_chapter_chunks.py deleted file mode 100644 index 8b73079e64..0000000000 --- a/backend/apps/ai/management/commands/ai_create_chapter_chunks.py +++ /dev/null @@ -1,161 +0,0 @@ -"""A command to create chunks of OWASP chapter data for RAG.""" - -import os - -import openai -from django.core.management.base import BaseCommand - -from apps.ai.common.constants import DELIMITER -from apps.ai.common.utils import create_chunks_and_embeddings -from apps.ai.models.chunk import Chunk -from apps.owasp.models.chapter import Chapter - - -class Command(BaseCommand): - help = "Create chunks for OWASP chapter data" - - def add_arguments(self, parser): - parser.add_argument( - "--chapter", - type=str, - help="Process only the chapter with this key", - ) - parser.add_argument( - "--all", - action="store_true", - help="Process all the chapters", - ) - parser.add_argument( - "--batch-size", - type=int, - default=50, - help="Number of chapters to process in each batch", - ) - - def handle(self, *args, **options): - if not (openai_api_key := os.getenv("DJANGO_OPEN_AI_SECRET_KEY")): - self.stdout.write( - self.style.ERROR("DJANGO_OPEN_AI_SECRET_KEY environment variable not set") - ) - return - - self.openai_client = openai.OpenAI(api_key=openai_api_key) - - if chapter := options["chapter"]: - queryset = Chapter.objects.filter(key=chapter) - elif options["all"]: - queryset = Chapter.objects.all() 
- else: - queryset = Chapter.objects.filter(is_active=True) - - if not (total_chapters := queryset.count()): - self.stdout.write("No chapters found to process") - return - - self.stdout.write(f"Found {total_chapters} chapters to process") - - batch_size = options["batch_size"] - for offset in range(0, total_chapters, batch_size): - batch_chapters = queryset[offset : offset + batch_size] - - batch_chunks = [] - for chapter in batch_chapters: - batch_chunks.extend(self.handle_chunks(chapter)) - - if batch_chunks: - Chunk.bulk_save(batch_chunks) - self.stdout.write(f"Saved {len(batch_chunks)} chunks") - - self.stdout.write(f"Completed processing all {total_chapters} chapters") - - def handle_chunks(self, chapter: Chapter) -> list[Chunk]: - """Create chunks from a chapter's data.""" - prose_content, metadata_content = self.extract_chapter_content(chapter) - - all_chunk_texts = [] - - if metadata_content.strip(): - all_chunk_texts.append(metadata_content) - - if prose_content.strip(): - all_chunk_texts.extend(Chunk.split_text(prose_content)) - - if not all_chunk_texts: - self.stdout.write(f"No content to chunk for chapter {chapter.key}") - return [] - - return create_chunks_and_embeddings( - all_chunk_texts=all_chunk_texts, - content_object=chapter, - openai_client=self.openai_client, - ) - - def extract_chapter_content(self, chapter: Chapter) -> tuple[str, str]: - """Extract and separate prose content from metadata for a chapter. 
- - Returns: - tuple[str, str]: (prose_content, metadata_content) - - """ - prose_parts = [] - metadata_parts = [] - - if chapter.description: - prose_parts.append(f"Description: {chapter.description}") - - if chapter.summary: - prose_parts.append(f"Summary: {chapter.summary}") - - if hasattr(chapter, "owasp_repository") and chapter.owasp_repository: - repo = chapter.owasp_repository - if repo.description: - prose_parts.append(f"Repository Description: {repo.description}") - if repo.topics: - metadata_parts.append(f"Repository Topics: {', '.join(repo.topics)}") - - if chapter.name: - metadata_parts.append(f"Chapter Name: {chapter.name}") - - location_parts = [] - if chapter.country: - location_parts.append(f"Country: {chapter.country}") - if chapter.region: - location_parts.append(f"Region: {chapter.region}") - if chapter.postal_code: - location_parts.append(f"Postal Code: {chapter.postal_code}") - if chapter.suggested_location: - location_parts.append(f"Location: {chapter.suggested_location}") - - if location_parts: - metadata_parts.append(f"Location Information: {', '.join(location_parts)}") - - if chapter.currency: - metadata_parts.append(f"Currency: {chapter.currency}") - - if chapter.meetup_group: - metadata_parts.append(f"Meetup Group: {chapter.meetup_group}") - - if chapter.tags: - metadata_parts.append(f"Tags: {', '.join(chapter.tags)}") - - if chapter.topics: - metadata_parts.append(f"Topics: {', '.join(chapter.topics)}") - - if chapter.leaders_raw: - metadata_parts.append(f"Chapter Leaders: {', '.join(chapter.leaders_raw)}") - - if chapter.related_urls: - valid_urls = [ - url - for url in chapter.related_urls - if url and url not in (chapter.invalid_urls or []) - ] - if valid_urls: - metadata_parts.append(f"Related URLs: {', '.join(valid_urls)}") - - metadata_parts.append(f"Active Chapter: {'Yes' if chapter.is_active else 'No'}") - - return ( - DELIMITER.join(filter(None, prose_parts)), - DELIMITER.join(filter(None, metadata_parts)), - ) diff --git 
a/backend/apps/ai/management/commands/ai_create_committee_chunks.py b/backend/apps/ai/management/commands/ai_create_committee_chunks.py deleted file mode 100644 index 6ae3771bc6..0000000000 --- a/backend/apps/ai/management/commands/ai_create_committee_chunks.py +++ /dev/null @@ -1,137 +0,0 @@ -"""A command to create chunks of OWASP committee data for RAG.""" - -import os - -import openai -from django.core.management.base import BaseCommand - -from apps.ai.common.constants import DELIMITER -from apps.ai.common.utils import create_chunks_and_embeddings -from apps.ai.models.chunk import Chunk -from apps.owasp.models.committee import Committee - - -class Command(BaseCommand): - help = "Create chunks for OWASP committee data" - - def add_arguments(self, parser): - parser.add_argument( - "--committee", - type=str, - help="Process only the committee with this key", - ) - parser.add_argument( - "--all", - action="store_true", - help="Process all the committees", - ) - parser.add_argument( - "--batch-size", - type=int, - default=50, - help="Number of committees to process in each batch", - ) - - def handle(self, *args, **options): - if not (openai_api_key := os.getenv("DJANGO_OPEN_AI_SECRET_KEY")): - self.stdout.write( - self.style.ERROR("DJANGO_OPEN_AI_SECRET_KEY environment variable not set") - ) - return - - self.openai_client = openai.OpenAI(api_key=openai_api_key) - - if committee := options["committee"]: - queryset = Committee.objects.filter(key=committee) - elif options["all"]: - queryset = Committee.objects.all() - else: - queryset = Committee.objects.filter(is_active=True) - - if not (total_committees := queryset.count()): - self.stdout.write("No committees found to process") - return - - self.stdout.write(f"Found {total_committees} committees to process") - - batch_size = options["batch_size"] - for offset in range(0, total_committees, batch_size): - batch_committees = queryset[offset : offset + batch_size] - - batch_chunks = [] - for committee in 
batch_committees: - batch_chunks.extend(self.handle_chunks(committee)) - - if batch_chunks: - Chunk.bulk_save(batch_chunks) - self.stdout.write(f"Saved {len(batch_chunks)} chunks") - - self.stdout.write(f"Completed processing all {total_committees} committees") - - def handle_chunks(self, committee: Committee) -> list[Chunk]: - """Create chunks from a committee's data.""" - prose_content, metadata_content = self.extract_committee_content(committee) - - all_chunk_texts = [] - - if metadata_content.strip(): - all_chunk_texts.append(metadata_content) - - if prose_content.strip(): - all_chunk_texts.extend(Chunk.split_text(prose_content)) - - if not all_chunk_texts: - self.stdout.write(f"No content to chunk for committee {committee.key}") - return [] - - return create_chunks_and_embeddings( - all_chunk_texts=all_chunk_texts, - content_object=committee, - openai_client=self.openai_client, - ) - - def extract_committee_content(self, committee: Committee) -> tuple[str, str]: - """Extract structured content from committee data.""" - prose_parts = [] - metadata_parts = [] - - if committee.description: - prose_parts.append(f"Description: {committee.description}") - - if committee.summary: - prose_parts.append(f"Summary: {committee.summary}") - - if hasattr(committee, "owasp_repository") and committee.owasp_repository: - repo = committee.owasp_repository - if repo.description: - prose_parts.append(f"Repository Description: {repo.description}") - if repo.topics: - metadata_parts.append(f"Repository Topics: {', '.join(repo.topics)}") - - if committee.name: - metadata_parts.append(f"Committee Name: {committee.name}") - - if committee.tags: - metadata_parts.append(f"Tags: {', '.join(committee.tags)}") - - if committee.topics: - metadata_parts.append(f"Topics: {', '.join(committee.topics)}") - - if committee.leaders_raw: - metadata_parts.append(f"Committee Leaders: {', '.join(committee.leaders_raw)}") - - if committee.related_urls: - valid_urls = [ - url - for url in 
committee.related_urls - if url and url not in (committee.invalid_urls or []) - ] - if valid_urls: - metadata_parts.append(f"Related URLs: {', '.join(valid_urls)}") - - metadata_parts.append(f"Active Committee: {'Yes' if committee.is_active else 'No'}") - - return ( - DELIMITER.join(filter(None, prose_parts)), - DELIMITER.join(filter(None, metadata_parts)), - ) diff --git a/backend/apps/ai/management/commands/ai_create_event_chunks.py b/backend/apps/ai/management/commands/ai_create_event_chunks.py deleted file mode 100644 index d0dab81a0c..0000000000 --- a/backend/apps/ai/management/commands/ai_create_event_chunks.py +++ /dev/null @@ -1,133 +0,0 @@ -"""A command to create chunks of OWASP event data for RAG.""" - -import os - -import openai -from django.core.management.base import BaseCommand - -from apps.ai.common.constants import DELIMITER -from apps.ai.common.utils import create_chunks_and_embeddings -from apps.ai.models.chunk import Chunk -from apps.owasp.models.event import Event - - -class Command(BaseCommand): - help = "Create chunks for OWASP event data" - - def add_arguments(self, parser): - parser.add_argument( - "--event", - type=str, - help="Process only the event with this key", - ) - parser.add_argument( - "--all", - action="store_true", - help="Process all the events", - ) - parser.add_argument( - "--batch-size", - type=int, - default=50, - help="Number of events to process in each batch", - ) - - def handle(self, *args, **options): - if not (openai_api_key := os.getenv("DJANGO_OPEN_AI_SECRET_KEY")): - self.stdout.write( - self.style.ERROR("DJANGO_OPEN_AI_SECRET_KEY environment variable not set") - ) - return - - self.openai_client = openai.OpenAI(api_key=openai_api_key) - - if event := options["event"]: - queryset = Event.objects.filter(key=event) - elif options["all"]: - queryset = Event.objects.all() - else: - queryset = Event.upcoming_events() - - if not (total_events := queryset.count()): - self.stdout.write("No events found to process") - return 
- - self.stdout.write(f"Found {total_events} events to process") - - batch_size = options["batch_size"] - for offset in range(0, total_events, batch_size): - batch_events = queryset[offset : offset + batch_size] - - batch_chunks = [] - for event in batch_events: - batch_chunks.extend(self.handle_chunks(event)) - - if batch_chunks: - Chunk.bulk_save(batch_chunks) - self.stdout.write(f"Saved {len(batch_chunks)} chunks") - - self.stdout.write(f"Completed processing all {total_events} events") - - def handle_chunks(self, event: Event) -> list[Chunk]: - """Create chunks from an event's data.""" - prose_content, metadata_content = self.extract_event_content(event) - - all_chunk_texts = [] - - if metadata_content.strip(): - all_chunk_texts.append(metadata_content) - - if prose_content.strip(): - all_chunk_texts.extend(Chunk.split_text(prose_content)) - - if not all_chunk_texts: - self.stdout.write(f"No content to chunk for event {event.key}") - return [] - - return create_chunks_and_embeddings( - all_chunk_texts, - content_object=event, - openai_client=self.openai_client, - ) - - def extract_event_content(self, event: Event) -> tuple[str, str]: - """Extract and separate prose content from metadata for an event. 
- - Returns: - tuple[str, str]: (prose_content, metadata_content) - - """ - prose_parts = [] - metadata_parts = [] - - if event.description: - prose_parts.append(f"Description: {event.description}") - - if event.summary: - prose_parts.append(f"Summary: {event.summary}") - - if event.name: - metadata_parts.append(f"Event Name: {event.name}") - - if event.category: - metadata_parts.append(f"Category: {event.get_category_display()}") - - if event.start_date: - metadata_parts.append(f"Start Date: {event.start_date}") - - if event.end_date: - metadata_parts.append(f"End Date: {event.end_date}") - - if event.suggested_location: - metadata_parts.append(f"Location: {event.suggested_location}") - - if event.latitude and event.longitude: - metadata_parts.append(f"Coordinates: {event.latitude}, {event.longitude}") - - if event.url: - metadata_parts.append(f"Event URL: {event.url}") - - return ( - DELIMITER.join(filter(None, prose_parts)), - DELIMITER.join(filter(None, metadata_parts)), - ) diff --git a/backend/apps/ai/management/commands/ai_create_project_chunks.py b/backend/apps/ai/management/commands/ai_create_project_chunks.py deleted file mode 100644 index d472ea9589..0000000000 --- a/backend/apps/ai/management/commands/ai_create_project_chunks.py +++ /dev/null @@ -1,178 +0,0 @@ -"""A command to create chunks of OWASP project data for RAG.""" - -import os - -import openai -from django.core.management.base import BaseCommand - -from apps.ai.common.constants import DELIMITER -from apps.ai.common.utils import create_chunks_and_embeddings -from apps.ai.models.chunk import Chunk -from apps.owasp.models.project import Project - - -class Command(BaseCommand): - help = "Create chunks for OWASP project data" - - def add_arguments(self, parser): - parser.add_argument( - "--project-key", type=str, help="Process only the project with this key" - ) - parser.add_argument("--all", action="store_true", help="Process all the projects") - parser.add_argument( - "--batch-size", - type=int, 
- default=50, - help="Number of projects to process in each batch", - ) - - def handle(self, *args, **options): - if not (openai_api_key := os.getenv("DJANGO_OPEN_AI_SECRET_KEY")): - self.stdout.write( - self.style.ERROR("DJANGO_OPEN_AI_SECRET_KEY environment variable not set") - ) - return - - self.openai_client = openai.OpenAI(api_key=openai_api_key) - - if options["project_key"]: - queryset = Project.objects.filter(key=options["project_key"]) - elif options["all"]: - queryset = Project.objects.all() - else: - queryset = Project.objects.filter(is_active=True) - - if not (total_projects := queryset.count()): - self.stdout.write("No projects found to process") - return - - self.stdout.write(f"Found {total_projects} projects to process") - - batch_size = options["batch_size"] - for offset in range(0, total_projects, batch_size): - batch_projects = queryset[offset : offset + batch_size] - - batch_chunks = [] - for project in batch_projects: - chunks = self.create_chunks(project) - batch_chunks.extend(chunks) - - if batch_chunks: - chunks_count = len(batch_chunks) - Chunk.bulk_save(batch_chunks) - self.stdout.write(f"Saved {chunks_count} chunks") - - self.stdout.write(f"Completed processing all {total_projects} projects") - - def create_chunks(self, project: Project) -> list[Chunk]: - prose_content, metadata_content = self.extract_project_content(project) - - all_chunk_texts = [] - - if metadata_content.strip(): - all_chunk_texts.append(metadata_content) - - if prose_content.strip(): - prose_chunks = Chunk.split_text(prose_content) - all_chunk_texts.extend(prose_chunks) - - if not all_chunk_texts: - self.stdout.write(f"No content to chunk for project {project.key}") - return [] - - return create_chunks_and_embeddings( - all_chunk_texts=all_chunk_texts, - content_object=project, - openai_client=self.openai_client, - ) - - def extract_project_content(self, project: Project) -> tuple[str, str]: - prose_parts = [] - metadata_parts = [] - - if project.name: - 
metadata_parts.append(f"Project Name: {project.name}") - - if project.description: - prose_parts.append(f"Description: {project.description}") - - if project.summary: - prose_parts.append(f"Summary: {project.summary}") - - if project.level: - metadata_parts.append(f"Project Level: {project.level}") - - if project.type: - metadata_parts.append(f"Project Type: {project.type}") - - if hasattr(project, "owasp_repository") and project.owasp_repository: - repo = project.owasp_repository - if repo.description: - prose_parts.append(f"Repository Description: {repo.description}") - if repo.topics: - metadata_parts.append(f"Repository Topics: {', '.join(repo.topics)}") - - if project.languages: - metadata_parts.append(f"Programming Languages: {', '.join(project.languages)}") - - if project.topics: - metadata_parts.append(f"Topics: {', '.join(project.topics)}") - - if project.licenses: - metadata_parts.append(f"Licenses: {', '.join(project.licenses)}") - - if project.tags: - metadata_parts.append(f"Tags: {', '.join(project.tags)}") - - if project.custom_tags: - metadata_parts.append(f"Custom Tags: {', '.join(project.custom_tags)}") - - stats_parts = [] - if project.stars_count > 0: - stats_parts.append(f"Stars: {project.stars_count}") - if project.forks_count > 0: - stats_parts.append(f"Forks: {project.forks_count}") - if project.contributors_count > 0: - stats_parts.append(f"Contributors: {project.contributors_count}") - if project.releases_count > 0: - stats_parts.append(f"Releases: {project.releases_count}") - if project.open_issues_count > 0: - stats_parts.append(f"Open Issues: {project.open_issues_count}") - - if stats_parts: - metadata_parts.append("Project Statistics: " + ", ".join(stats_parts)) - - if project.leaders_raw: - metadata_parts.append(f"Project Leaders: {', '.join(project.leaders_raw)}") - - if project.related_urls: - valid_urls = [ - url - for url in project.related_urls - if url and url not in (project.invalid_urls or []) - ] - if valid_urls: - 
metadata_parts.append(f"Related URLs: {', '.join(valid_urls)}") - - if project.created_at: - metadata_parts.append(f"Created: {project.created_at.strftime('%Y-%m-%d')}") - - if project.updated_at: - metadata_parts.append(f"Last Updated: {project.updated_at.strftime('%Y-%m-%d')}") - - if project.released_at: - metadata_parts.append(f"Last Release: {project.released_at.strftime('%Y-%m-%d')}") - - if project.health_score is not None: - metadata_parts.append(f"Health Score: {project.health_score:.2f}") - - metadata_parts.append(f"Active Project: {'Yes' if project.is_active else 'No'}") - - metadata_parts.append( - f"Issue Tracking: {'Enabled' if project.track_issues else 'Disabled'}" - ) - - return ( - DELIMITER.join(filter(None, prose_parts)), - DELIMITER.join(filter(None, metadata_parts)), - ) diff --git a/backend/apps/ai/management/commands/ai_create_slack_message_chunks.py b/backend/apps/ai/management/commands/ai_create_slack_message_chunks.py deleted file mode 100644 index 30b20e0f39..0000000000 --- a/backend/apps/ai/management/commands/ai_create_slack_message_chunks.py +++ /dev/null @@ -1,56 +0,0 @@ -"""A command to create chunks of Slack messages.""" - -import os - -import openai -from django.core.management.base import BaseCommand - -from apps.ai.common.utils import create_chunks_and_embeddings -from apps.ai.models.chunk import Chunk -from apps.slack.models.message import Message - - -class Command(BaseCommand): - help = "Create chunks for Slack messages" - - def handle(self, *args, **options): - if not (openai_api_key := os.getenv("DJANGO_OPEN_AI_SECRET_KEY")): - self.stdout.write( - self.style.ERROR("DJANGO_OPEN_AI_SECRET_KEY environment variable not set") - ) - return - - self.openai_client = openai.OpenAI(api_key=openai_api_key) - - total_messages = Message.objects.count() - self.stdout.write(f"Found {total_messages} messages to process") - - batch_size = 100 - for offset in range(0, total_messages, batch_size): - Chunk.bulk_save( - [ - chunk - for message 
in Message.objects.all()[offset : offset + batch_size] - for chunk in self.handle_chunks(message) - ] - ) - - self.stdout.write(f"Completed processing all {total_messages} messages") - - def handle_chunks(self, message: Message) -> list[Chunk]: - """Create chunks from a message.""" - if message.subtype in {"channel_join", "channel_leave"}: - return [] - - if not (chunk_text := Chunk.split_text(message.cleaned_text)): - self.stdout.write( - f"No chunks created for message {message.slack_message_id}: " - f"`{message.cleaned_text}`" - ) - return [] - - return create_chunks_and_embeddings( - all_chunk_texts=chunk_text, - content_object=message, - openai_client=self.openai_client, - ) diff --git a/backend/apps/ai/management/commands/ai_update_chapter_chunks.py b/backend/apps/ai/management/commands/ai_update_chapter_chunks.py new file mode 100644 index 0000000000..cee3d355c5 --- /dev/null +++ b/backend/apps/ai/management/commands/ai_update_chapter_chunks.py @@ -0,0 +1,14 @@ +"""A command to create chunks of OWASP chapter data for RAG.""" + +from apps.ai.common.base.chunk_command import BaseChunkCommand +from apps.ai.common.extractors.chapter import extract_chapter_content +from apps.owasp.models.chapter import Chapter + + +class Command(BaseChunkCommand): + key_field_name = "key" + model_class = Chapter + + def extract_content(self, entity: Chapter) -> tuple[str, str]: + """Extract content from the chapter.""" + return extract_chapter_content(entity) diff --git a/backend/apps/ai/management/commands/ai_update_chapter_context.py b/backend/apps/ai/management/commands/ai_update_chapter_context.py new file mode 100644 index 0000000000..2c72f89c84 --- /dev/null +++ b/backend/apps/ai/management/commands/ai_update_chapter_context.py @@ -0,0 +1,14 @@ +"""A command to update context for OWASP chapter data.""" + +from apps.ai.common.base.context_command import BaseContextCommand +from apps.ai.common.extractors.chapter import extract_chapter_content +from apps.owasp.models.chapter 
import Chapter + + +class Command(BaseContextCommand): + key_field_name = "key" + model_class = Chapter + + def extract_content(self, entity: Chapter) -> tuple[str, str]: + """Extract content from the chapter.""" + return extract_chapter_content(entity) diff --git a/backend/apps/ai/management/commands/ai_update_committee_chunks.py b/backend/apps/ai/management/commands/ai_update_committee_chunks.py new file mode 100644 index 0000000000..611dba01fb --- /dev/null +++ b/backend/apps/ai/management/commands/ai_update_committee_chunks.py @@ -0,0 +1,14 @@ +"""A command to create chunks of OWASP committee data for RAG.""" + +from apps.ai.common.base.chunk_command import BaseChunkCommand +from apps.ai.common.extractors.committee import extract_committee_content +from apps.owasp.models.committee import Committee + + +class Command(BaseChunkCommand): + key_field_name = "key" + model_class = Committee + + def extract_content(self, entity: Committee) -> tuple[str, str]: + """Extract content from the committee.""" + return extract_committee_content(entity) diff --git a/backend/apps/ai/management/commands/ai_update_committee_context.py b/backend/apps/ai/management/commands/ai_update_committee_context.py new file mode 100644 index 0000000000..4b3bf29cda --- /dev/null +++ b/backend/apps/ai/management/commands/ai_update_committee_context.py @@ -0,0 +1,14 @@ +"""A command to update context for OWASP committee data.""" + +from apps.ai.common.base.context_command import BaseContextCommand +from apps.ai.common.extractors.committee import extract_committee_content +from apps.owasp.models.committee import Committee + + +class Command(BaseContextCommand): + key_field_name = "key" + model_class = Committee + + def extract_content(self, entity: Committee) -> tuple[str, str]: + """Extract content from the committee.""" + return extract_committee_content(entity) diff --git a/backend/apps/ai/management/commands/ai_update_event_chunks.py 
b/backend/apps/ai/management/commands/ai_update_event_chunks.py new file mode 100644 index 0000000000..fa5bcbf5c4 --- /dev/null +++ b/backend/apps/ai/management/commands/ai_update_event_chunks.py @@ -0,0 +1,24 @@ +"""A command to create chunks of OWASP event data for RAG.""" + +from django.db.models import QuerySet + +from apps.ai.common.base.chunk_command import BaseChunkCommand +from apps.ai.common.extractors.event import extract_event_content +from apps.owasp.models.event import Event + + +class Command(BaseChunkCommand): + key_field_name = "key" + model_class = Event + + def extract_content(self, entity: Event) -> tuple[str, str]: + """Extract content from the event.""" + return extract_event_content(entity) + + def get_base_queryset(self) -> QuerySet: + """Return the base queryset with ordering.""" + return super().get_base_queryset() + + def get_default_queryset(self) -> QuerySet: + """Return upcoming events by default instead of is_active filter.""" + return Event.upcoming_events() diff --git a/backend/apps/ai/management/commands/ai_update_event_context.py b/backend/apps/ai/management/commands/ai_update_event_context.py new file mode 100644 index 0000000000..15232a773a --- /dev/null +++ b/backend/apps/ai/management/commands/ai_update_event_context.py @@ -0,0 +1,24 @@ +"""A command to update context for OWASP event data.""" + +from django.db.models import QuerySet + +from apps.ai.common.base.context_command import BaseContextCommand +from apps.ai.common.extractors.event import extract_event_content +from apps.owasp.models.event import Event + + +class Command(BaseContextCommand): + key_field_name = "key" + model_class = Event + + def extract_content(self, entity: Event) -> tuple[str, str]: + """Extract content from the event.""" + return extract_event_content(entity) + + def get_base_queryset(self) -> QuerySet: + """Return the base queryset with ordering.""" + return super().get_base_queryset() + + def get_default_queryset(self) -> QuerySet: + """Return 
upcoming events by default instead of is_active filter.""" + return Event.upcoming_events() diff --git a/backend/apps/ai/management/commands/ai_update_project_chunks.py b/backend/apps/ai/management/commands/ai_update_project_chunks.py new file mode 100644 index 0000000000..132e8fad0b --- /dev/null +++ b/backend/apps/ai/management/commands/ai_update_project_chunks.py @@ -0,0 +1,20 @@ +"""A command to create chunks of OWASP project data for RAG.""" + +from django.db.models import QuerySet + +from apps.ai.common.base.chunk_command import BaseChunkCommand +from apps.ai.common.extractors.project import extract_project_content +from apps.owasp.models.project import Project + + +class Command(BaseChunkCommand): + key_field_name = "key" + model_class = Project + + def extract_content(self, entity: Project) -> tuple[str, str]: + """Extract content from the project.""" + return extract_project_content(entity) + + def get_base_queryset(self) -> QuerySet: + """Return the base queryset with ordering.""" + return super().get_base_queryset() diff --git a/backend/apps/ai/management/commands/ai_update_project_context.py b/backend/apps/ai/management/commands/ai_update_project_context.py new file mode 100644 index 0000000000..d1aede6d98 --- /dev/null +++ b/backend/apps/ai/management/commands/ai_update_project_context.py @@ -0,0 +1,20 @@ +"""A command to update context for OWASP project data.""" + +from django.db.models import QuerySet + +from apps.ai.common.base.context_command import BaseContextCommand +from apps.ai.common.extractors.project import extract_project_content +from apps.owasp.models.project import Project + + +class Command(BaseContextCommand): + key_field_name = "key" + model_class = Project + + def extract_content(self, entity: Project) -> tuple[str, str]: + """Extract content from the project.""" + return extract_project_content(entity) + + def get_base_queryset(self) -> QuerySet: + """Return the base queryset with ordering.""" + return super().get_base_queryset() 
diff --git a/backend/apps/ai/management/commands/ai_update_slack_message_chunks.py b/backend/apps/ai/management/commands/ai_update_slack_message_chunks.py new file mode 100644 index 0000000000..51985a2adf --- /dev/null +++ b/backend/apps/ai/management/commands/ai_update_slack_message_chunks.py @@ -0,0 +1,41 @@ +"""A command to create chunks of Slack messages.""" + +from django.db.models import QuerySet + +from apps.ai.common.base.chunk_command import BaseChunkCommand +from apps.slack.models.message import Message + + +class Command(BaseChunkCommand): + key_field_name = "slack_message_id" + model_class = Message + + def add_arguments(self, parser): + """Override to use different default batch size for messages.""" + parser.add_argument( + "--message-key", + type=str, + help="Process only the message with this key", + ) + parser.add_argument( + "--all", + action="store_true", + help="Process all the messages", + ) + parser.add_argument( + "--batch-size", + type=int, + default=100, + help="Number of messages to process in each batch", + ) + + def extract_content(self, entity: Message) -> tuple[str, str]: + """Extract content from the message.""" + return entity.cleaned_text or "", "" + + def get_default_queryset(self) -> QuerySet: + """Return all messages by default since Message model doesn't have is_active field.""" + return self.get_base_queryset() + + def source_name(self) -> str: + return "slack_message" diff --git a/backend/apps/ai/management/commands/ai_update_slack_message_context.py b/backend/apps/ai/management/commands/ai_update_slack_message_context.py new file mode 100644 index 0000000000..c89b253692 --- /dev/null +++ b/backend/apps/ai/management/commands/ai_update_slack_message_context.py @@ -0,0 +1,42 @@ +"""A command to update context for Slack message data.""" + +from django.db.models import QuerySet + +from apps.ai.common.base.context_command import BaseContextCommand +from apps.slack.models.message import Message + + +class 
Command(BaseContextCommand): + key_field_name = "slack_message_id" + model_class = Message + + def add_arguments(self, parser): + """Override to use different default batch size for messages.""" + super().add_arguments(parser) + parser.add_argument( + "--message-key", + type=str, + help="Process only the message with this key", + ) + parser.add_argument( + "--all", + action="store_true", + help="Process all the messages", + ) + parser.add_argument( + "--batch-size", + type=int, + default=100, + help="Number of messages to process in each batch", + ) + + def extract_content(self, entity: Message) -> tuple[str, str]: + """Extract content from the message.""" + return entity.cleaned_text or "", "" + + def get_default_queryset(self) -> QuerySet: + """Return all messages by default since Message model doesn't have is_active field.""" + return self.get_base_queryset() + + def source_name(self) -> str: + return "slack_message" diff --git a/backend/apps/ai/migrations/0005_context_alter_chunk_unique_together_chunk_context_and_more.py b/backend/apps/ai/migrations/0005_context_alter_chunk_unique_together_chunk_context_and_more.py new file mode 100644 index 0000000000..38cd76b854 --- /dev/null +++ b/backend/apps/ai/migrations/0005_context_alter_chunk_unique_together_chunk_context_and_more.py @@ -0,0 +1,70 @@ +# Generated by Django 5.2.4 on 2025-07-28 09:00 + +import django.db.models.deletion +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("ai", "0004_alter_chunk_unique_together_chunk_content_type_and_more"), + ("contenttypes", "0002_remove_content_type_name"), + ] + + operations = [ + migrations.CreateModel( + name="Context", + fields=[ + ( + "id", + models.BigAutoField( + auto_created=True, primary_key=True, serialize=False, verbose_name="ID" + ), + ), + ("nest_created_at", models.DateTimeField(auto_now_add=True)), + ("nest_updated_at", models.DateTimeField(auto_now=True)), + ("generated_text", 
models.TextField(verbose_name="Generated Text")), + ("object_id", models.PositiveIntegerField(default=0)), + ("source", models.CharField(blank=True, default="", max_length=100)), + ( + "content_type", + models.ForeignKey( + blank=True, + null=True, + on_delete=django.db.models.deletion.CASCADE, + to="contenttypes.contenttype", + ), + ), + ], + options={ + "verbose_name": "Context", + "db_table": "ai_contexts", + }, + ), + migrations.AlterUniqueTogether( + name="chunk", + unique_together=set(), + ), + migrations.AddField( + model_name="chunk", + name="context", + field=models.ForeignKey( + blank=True, + null=True, + on_delete=django.db.models.deletion.CASCADE, + related_name="chunks", + to="ai.context", + ), + ), + migrations.AlterUniqueTogether( + name="chunk", + unique_together={("context", "text")}, + ), + migrations.RemoveField( + model_name="chunk", + name="content_type", + ), + migrations.RemoveField( + model_name="chunk", + name="object_id", + ), + ] diff --git a/backend/apps/ai/migrations/0006_rename_generated_text_context_content_and_more.py b/backend/apps/ai/migrations/0006_rename_generated_text_context_content_and_more.py new file mode 100644 index 0000000000..370700992d --- /dev/null +++ b/backend/apps/ai/migrations/0006_rename_generated_text_context_content_and_more.py @@ -0,0 +1,28 @@ +# Generated by Django 5.2.4 on 2025-07-30 12:49 + +import django.db.models.deletion +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("ai", "0005_context_alter_chunk_unique_together_chunk_context_and_more"), + ] + + operations = [ + migrations.RenameField( + model_name="context", + old_name="generated_text", + new_name="content", + ), + migrations.AlterField( + model_name="chunk", + name="context", + field=models.ForeignKey( + default="", + on_delete=django.db.models.deletion.CASCADE, + related_name="chunks", + to="ai.context", + ), + ), + ] diff --git 
a/backend/apps/ai/migrations/0007_alter_chunk_context_alter_context_unique_together.py b/backend/apps/ai/migrations/0007_alter_chunk_context_alter_context_unique_together.py new file mode 100644 index 0000000000..da449203b4 --- /dev/null +++ b/backend/apps/ai/migrations/0007_alter_chunk_context_alter_context_unique_together.py @@ -0,0 +1,25 @@ +# Generated by Django 5.2.4 on 2025-07-30 18:47 + +import django.db.models.deletion +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("ai", "0006_rename_generated_text_context_content_and_more"), + ("contenttypes", "0002_remove_content_type_name"), + ] + + operations = [ + migrations.AlterField( + model_name="chunk", + name="context", + field=models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, related_name="chunks", to="ai.context" + ), + ), + migrations.AlterUniqueTogether( + name="context", + unique_together={("content_type", "object_id")}, + ), + ] diff --git a/backend/apps/ai/migrations/0008_alter_context_unique_together_and_more.py b/backend/apps/ai/migrations/0008_alter_context_unique_together_and_more.py new file mode 100644 index 0000000000..2d807dd926 --- /dev/null +++ b/backend/apps/ai/migrations/0008_alter_context_unique_together_and_more.py @@ -0,0 +1,34 @@ +# Generated by Django 5.2.4 on 2025-08-10 19:06 + +import django.db.models.deletion +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("ai", "0007_alter_chunk_context_alter_context_unique_together"), + ("contenttypes", "0002_remove_content_type_name"), + ] + + operations = [ + migrations.AlterUniqueTogether( + name="context", + unique_together=set(), + ), + migrations.AlterField( + model_name="context", + name="content_type", + field=models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, to="contenttypes.contenttype" + ), + ), + migrations.AlterField( + model_name="context", + name="object_id", + 
field=models.PositiveIntegerField(), + ), + migrations.AlterUniqueTogether( + name="context", + unique_together={("content_type", "object_id", "content")}, + ), + ] diff --git a/backend/apps/ai/migrations/0009_rename_object_id_context_entity_id_and_more.py b/backend/apps/ai/migrations/0009_rename_object_id_context_entity_id_and_more.py new file mode 100644 index 0000000000..e8f3b27aaa --- /dev/null +++ b/backend/apps/ai/migrations/0009_rename_object_id_context_entity_id_and_more.py @@ -0,0 +1,27 @@ +# Generated by Django 5.2.5 on 2025-08-15 08:40 + +from django.db import migrations + + +class Migration(migrations.Migration): + dependencies = [ + ("ai", "0008_alter_context_unique_together_and_more"), + ("contenttypes", "0002_remove_content_type_name"), + ] + + operations = [ + migrations.RenameField( + model_name="context", + old_name="object_id", + new_name="entity_id", + ), + migrations.RenameField( + model_name="context", + old_name="content_type", + new_name="entity_type", + ), + migrations.AlterUniqueTogether( + name="context", + unique_together={("entity_type", "entity_id")}, + ), + ] diff --git a/backend/apps/ai/models/chunk.py b/backend/apps/ai/models/chunk.py index 8362948ffe..8dfcaf0022 100644 --- a/backend/apps/ai/models/chunk.py +++ b/backend/apps/ai/models/chunk.py @@ -1,11 +1,10 @@ """AI app chunk model.""" -from django.contrib.contenttypes.fields import GenericForeignKey -from django.contrib.contenttypes.models import ContentType from django.db import models from langchain.text_splitter import RecursiveCharacterTextSplitter from pgvector.django import VectorField +from apps.ai.models.context import Context from apps.common.models import BulkSaveModel, TimestampedModel from apps.common.utils import truncate @@ -16,25 +15,16 @@ class Chunk(TimestampedModel): class Meta: db_table = "ai_chunks" verbose_name = "Chunk" - unique_together = ("content_type", "object_id", "text") + unique_together = ("context", "text") - content_object = 
GenericForeignKey("content_type", "object_id") - content_type = models.ForeignKey(ContentType, on_delete=models.CASCADE, blank=True, null=True) + context = models.ForeignKey(Context, on_delete=models.CASCADE, related_name="chunks") embedding = VectorField(verbose_name="Embedding", dimensions=1536) - object_id = models.PositiveIntegerField(default=0) text = models.TextField(verbose_name="Text") def __str__(self): """Human readable representation.""" - content_name = ( - getattr(self.content_object, "name", None) - or getattr(self.content_object, "key", None) - or str(self.content_object) - ) - return ( - f"Chunk {self.id} for {self.content_type.model} {content_name}: " - f"{truncate(self.text, 50)}" - ) + context_str = str(self.context) if self.context else "No Context" + return f"Chunk {self.id} for {context_str}: {truncate(self.text, 50)}" @staticmethod def bulk_save(chunks, fields=None): @@ -54,8 +44,8 @@ def split_text(text: str) -> list[str]: @staticmethod def update_data( text: str, - content_object, embedding, + context: Context, *, save: bool = True, ) -> "Chunk | None": @@ -63,24 +53,22 @@ def update_data( Args: text (str): The text content of the chunk. - content_object: The object this chunk belongs to (Message, Chapter, etc.). embedding (list): The embedding vector for the chunk. + context (Context): The context this chunk belongs to. save (bool): Whether to save the chunk to the database. Returns: - Chunk: The updated chunk instance or None if it already exists. + Chunk: The created chunk instance. 
""" - content_type = ContentType.objects.get_for_model(content_object) - if Chunk.objects.filter( - content_type=content_type, object_id=content_object.id, text=text + context__entity_type=context.entity_type, + context__entity_id=context.entity_id, + text=text, ).exists(): return None - chunk = Chunk( - content_type=content_type, object_id=content_object.id, text=text, embedding=embedding - ) + chunk = Chunk(text=text, embedding=embedding, context=context) if save: chunk.save() diff --git a/backend/apps/ai/models/context.py b/backend/apps/ai/models/context.py new file mode 100644 index 0000000000..e3d93bc277 --- /dev/null +++ b/backend/apps/ai/models/context.py @@ -0,0 +1,61 @@ +"""AI app context model.""" + +import logging + +from django.contrib.contenttypes.fields import GenericForeignKey +from django.contrib.contenttypes.models import ContentType +from django.db import models + +from apps.common.models import TimestampedModel +from apps.common.utils import truncate + +logger = logging.getLogger(__name__) + + +class Context(TimestampedModel): + """Context model for storing generated text related to OWASP entities.""" + + content = models.TextField(verbose_name="Generated Text") + entity_type = models.ForeignKey(ContentType, on_delete=models.CASCADE) + entity_id = models.PositiveIntegerField() + entity = GenericForeignKey("entity_type", "entity_id") + source = models.CharField(max_length=100, blank=True, default="") + + class Meta: + db_table = "ai_contexts" + verbose_name = "Context" + unique_together = ("entity_type", "entity_id") + + def __str__(self): + """Human readable representation.""" + entity = ( + getattr(self.entity, "name", None) + or getattr(self.entity, "key", None) + or str(self.entity) + ) + return f"{self.entity_type.model} {entity}: {truncate(self.content, 50)}" + + @staticmethod + def update_data( + content: str, + entity, + source: str = "", + *, + save: bool = True, + ) -> "Context": + """Create or update context for a given entity.""" + 
entity_type = ContentType.objects.get_for_model(entity) + entity_id = entity.pk + + try: + context = Context.objects.get(entity_type=entity_type, entity_id=entity_id) + except Context.DoesNotExist: + context = Context(entity_type=entity_type, entity_id=entity_id) + + context.content = content + context.source = source + + if save: + context.save() + + return context diff --git a/backend/tests/apps/ai/agent/__init__.py b/backend/tests/apps/ai/agent/__init__.py new file mode 100644 index 0000000000..60947f7d73 --- /dev/null +++ b/backend/tests/apps/ai/agent/__init__.py @@ -0,0 +1 @@ +"""AI agent tests package.""" diff --git a/backend/tests/apps/ai/agent/tools/__init__.py b/backend/tests/apps/ai/agent/tools/__init__.py new file mode 100644 index 0000000000..e60b458847 --- /dev/null +++ b/backend/tests/apps/ai/agent/tools/__init__.py @@ -0,0 +1 @@ +"""AI agent tools tests package.""" diff --git a/backend/tests/apps/ai/agent/tools/rag/__init__.py b/backend/tests/apps/ai/agent/tools/rag/__init__.py new file mode 100644 index 0000000000..9e0722de20 --- /dev/null +++ b/backend/tests/apps/ai/agent/tools/rag/__init__.py @@ -0,0 +1 @@ +"""RAG tests package.""" diff --git a/backend/tests/apps/ai/agent/tools/rag/generator_test.py b/backend/tests/apps/ai/agent/tools/rag/generator_test.py new file mode 100644 index 0000000000..903bdf636d --- /dev/null +++ b/backend/tests/apps/ai/agent/tools/rag/generator_test.py @@ -0,0 +1,197 @@ +"""Tests for the RAG Generator.""" + +import os +from unittest.mock import MagicMock, patch + +import openai +import pytest + +from apps.ai.agent.tools.rag.generator import Generator + + +class TestGenerator: + """Test cases for the Generator class.""" + + def test_init_success(self): + """Test successful initialization with API key.""" + with ( + patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}), + patch("openai.OpenAI") as mock_openai, + ): + mock_client = MagicMock() + mock_openai.return_value = mock_client + + generator = 
Generator(chat_model="gpt-4") + + assert generator.chat_model == "gpt-4" + assert generator.openai_client == mock_client + mock_openai.assert_called_once_with(api_key="test-key") + + def test_init_no_api_key(self): + """Test initialization fails when API key is not set.""" + with ( + patch.dict(os.environ, {}, clear=True), + pytest.raises( + ValueError, + match="DJANGO_OPEN_AI_SECRET_KEY environment variable not set", + ), + ): + Generator() + + def test_prepare_context_empty_chunks(self): + """Test context preparation with empty chunks list.""" + with ( + patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}), + patch("openai.OpenAI"), + ): + generator = Generator() + + result = generator.prepare_context([]) + + assert result == "No context provided" + + def test_prepare_context_with_chunks(self): + """Test context preparation with valid chunks.""" + with ( + patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}), + patch("openai.OpenAI"), + ): + generator = Generator() + + chunks = [ + {"source_name": "Chapter 1", "text": "This is chapter 1 content"}, + {"source_name": "Chapter 2", "text": "This is chapter 2 content"}, + ] + + result = generator.prepare_context(chunks) + + expected = ( + "Source Name: Chapter 1\nContent: This is chapter 1 content\n\n" + "---\n\n" + "Source Name: Chapter 2\nContent: This is chapter 2 content" + ) + assert result == expected + + def test_prepare_context_missing_fields(self): + """Test context preparation with chunks missing fields.""" + with ( + patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}), + patch("openai.OpenAI"), + ): + generator = Generator() + + chunks = [ + {"text": "Content without source"}, + {"source_name": "Source without content"}, + {}, + ] + + result = generator.prepare_context(chunks) + + expected = ( + "Source Name: Unknown Source 1\nContent: Content without source\n\n" + "---\n\n" + "Source Name: Source without content\nContent: \n\n" + "---\n\n" + "Source Name: Unknown 
Source 3\nContent: " + ) + assert result == expected + + def test_generate_answer_success(self): + """Test successful answer generation.""" + with ( + patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}), + patch("openai.OpenAI") as mock_openai, + ): + mock_client = MagicMock() + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message.content = "Generated answer" + mock_client.chat.completions.create.return_value = mock_response + mock_openai.return_value = mock_client + + generator = Generator() + + chunks = [{"source_name": "Test", "text": "Test content"}] + result = generator.generate_answer("What is OWASP?", chunks) + + assert result == "Generated answer" + mock_client.chat.completions.create.assert_called_once() + call_args = mock_client.chat.completions.create.call_args + assert call_args[1]["model"] == "gpt-4o" + assert len(call_args[1]["messages"]) == 2 + assert call_args[1]["messages"][0]["role"] == "system" + assert call_args[1]["messages"][1]["role"] == "user" + + def test_generate_answer_with_custom_model(self): + """Test answer generation with custom chat model.""" + with ( + patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}), + patch("openai.OpenAI") as mock_openai, + ): + mock_client = MagicMock() + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message.content = "Custom model answer" + mock_client.chat.completions.create.return_value = mock_response + mock_openai.return_value = mock_client + + generator = Generator(chat_model="gpt-3.5-turbo") + + chunks = [{"source_name": "Test", "text": "Test content"}] + result = generator.generate_answer("Test query", chunks) + + assert result == "Custom model answer" + call_args = mock_client.chat.completions.create.call_args + assert call_args[1]["model"] == "gpt-3.5-turbo" + + def test_generate_answer_openai_error(self): + """Test answer generation with OpenAI API error.""" + with ( + 
patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}), + patch("openai.OpenAI") as mock_openai, + ): + mock_client = MagicMock() + mock_client.chat.completions.create.side_effect = openai.OpenAIError("API Error") + mock_openai.return_value = mock_client + + generator = Generator() + + chunks = [{"source_name": "Test", "text": "Test content"}] + result = generator.generate_answer("Test query", chunks) + + assert result == "I'm sorry, I'm currently unable to process your request." + + def test_generate_answer_with_empty_chunks(self): + """Test answer generation with empty chunks.""" + with ( + patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}), + patch("openai.OpenAI") as mock_openai, + ): + mock_client = MagicMock() + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message.content = "No context answer" + mock_client.chat.completions.create.return_value = mock_response + mock_openai.return_value = mock_client + + generator = Generator() + + result = generator.generate_answer("Test query", []) + + assert result == "No context answer" + call_args = mock_client.chat.completions.create.call_args + assert "No context provided" in call_args[1]["messages"][1]["content"] + + def test_system_prompt_content(self): + """Test that system prompt contains expected content.""" + assert "OWASP Foundation" in Generator.SYSTEM_PROMPT + assert "context" in Generator.SYSTEM_PROMPT.lower() + assert "professional" in Generator.SYSTEM_PROMPT.lower() + assert "latitude and longitude" in Generator.SYSTEM_PROMPT.lower() + + def test_constants(self): + """Test class constants have expected values.""" + assert Generator.MAX_TOKENS == 2000 + assert isinstance(Generator.SYSTEM_PROMPT, str) + assert len(Generator.SYSTEM_PROMPT) > 0 diff --git a/backend/tests/apps/ai/agent/tools/rag/rag_tool_test.py b/backend/tests/apps/ai/agent/tools/rag/rag_tool_test.py new file mode 100644 index 0000000000..dcb24ca970 --- /dev/null +++ 
b/backend/tests/apps/ai/agent/tools/rag/rag_tool_test.py @@ -0,0 +1,179 @@ +"""Tests for the RAG Tool.""" + +import os +from unittest.mock import MagicMock, patch + +import pytest + +from apps.ai.agent.tools.rag.rag_tool import RagTool + + +class TestRagTool: + """Test cases for the RagTool class.""" + + def test_init_success(self): + """Test successful initialization of RagTool.""" + with ( + patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}), + patch("apps.ai.agent.tools.rag.rag_tool.Retriever") as mock_retriever_class, + patch("apps.ai.agent.tools.rag.rag_tool.Generator") as mock_generator_class, + ): + mock_retriever = MagicMock() + mock_generator = MagicMock() + mock_retriever_class.return_value = mock_retriever + mock_generator_class.return_value = mock_generator + + rag_tool = RagTool( + embedding_model="custom-embedding-model", chat_model="custom-chat-model" + ) + + assert rag_tool.retriever == mock_retriever + assert rag_tool.generator == mock_generator + mock_retriever_class.assert_called_once_with(embedding_model="custom-embedding-model") + mock_generator_class.assert_called_once_with(chat_model="custom-chat-model") + + def test_init_default_models(self): + """Test initialization with default model parameters.""" + with ( + patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}), + patch("apps.ai.agent.tools.rag.rag_tool.Retriever") as mock_retriever_class, + patch("apps.ai.agent.tools.rag.rag_tool.Generator") as mock_generator_class, + ): + mock_retriever = MagicMock() + mock_generator = MagicMock() + mock_retriever_class.return_value = mock_retriever + mock_generator_class.return_value = mock_generator + + RagTool() + + mock_retriever_class.assert_called_once_with(embedding_model="text-embedding-3-small") + mock_generator_class.assert_called_once_with(chat_model="gpt-4o") + + def test_init_no_api_key(self): + """Test initialization fails when API key is not set.""" + with ( + patch.dict(os.environ, {}, clear=True), + 
patch("apps.ai.agent.tools.rag.rag_tool.Retriever") as mock_retriever_class, + ): + mock_retriever_class.side_effect = ValueError( + "DJANGO_OPEN_AI_SECRET_KEY environment variable not set" + ) + + with pytest.raises( + ValueError, + match="DJANGO_OPEN_AI_SECRET_KEY environment variable not set", + ): + RagTool() + + def test_query_success(self): + """Test successful query execution.""" + with ( + patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}), + patch("apps.ai.agent.tools.rag.rag_tool.Retriever") as mock_retriever_class, + patch("apps.ai.agent.tools.rag.rag_tool.Generator") as mock_generator_class, + ): + mock_retriever = MagicMock() + mock_generator = MagicMock() + mock_retriever_class.return_value = mock_retriever + mock_generator_class.return_value = mock_generator + + mock_chunks = [{"text": "Test content", "source_name": "Test Source"}] + mock_retriever.retrieve.return_value = mock_chunks + mock_generator.generate_answer.return_value = "Generated answer" + + rag_tool = RagTool() + + result = rag_tool.query( + question="What is OWASP?", + content_types=["chapter"], + limit=10, + similarity_threshold=0.5, + ) + + assert result == "Generated answer" + mock_retriever.retrieve.assert_called_once_with( + content_types=["chapter"], + limit=10, + query="What is OWASP?", + similarity_threshold=0.5, + ) + mock_generator.generate_answer.assert_called_once_with( + context_chunks=mock_chunks, query="What is OWASP?" 
+ ) + + def test_query_with_defaults(self): + """Test query with default parameters.""" + with ( + patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}), + patch("apps.ai.agent.tools.rag.rag_tool.Retriever") as mock_retriever_class, + patch("apps.ai.agent.tools.rag.rag_tool.Generator") as mock_generator_class, + ): + mock_retriever = MagicMock() + mock_generator = MagicMock() + mock_retriever_class.return_value = mock_retriever + mock_generator_class.return_value = mock_generator + + mock_chunks = [] + mock_retriever.retrieve.return_value = mock_chunks + mock_generator.generate_answer.return_value = "Default answer" + + rag_tool = RagTool() + + result = rag_tool.query("Test question") + + assert result == "Default answer" + mock_retriever.retrieve.assert_called_once_with( + content_types=None, + limit=5, + query="Test question", + similarity_threshold=0.4, + ) + + def test_query_empty_content_types(self): + """Test query with empty content types list.""" + with ( + patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}), + patch("apps.ai.agent.tools.rag.rag_tool.Retriever") as mock_retriever_class, + patch("apps.ai.agent.tools.rag.rag_tool.Generator") as mock_generator_class, + ): + mock_retriever = MagicMock() + mock_generator = MagicMock() + mock_retriever_class.return_value = mock_retriever + mock_generator_class.return_value = mock_generator + + mock_chunks = [] + mock_retriever.retrieve.return_value = mock_chunks + mock_generator.generate_answer.return_value = "Answer" + + rag_tool = RagTool() + + result = rag_tool.query("Test question", content_types=[]) + + assert result == "Answer" + mock_retriever.retrieve.assert_called_once_with( + content_types=[], + limit=5, + query="Test question", + similarity_threshold=0.4, + ) + + @patch("apps.ai.agent.tools.rag.rag_tool.logger") + def test_query_logs_retrieval(self, mock_logger): + """Test that query logs the retrieval process.""" + with ( + patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": 
"test-key"}), + patch("apps.ai.agent.tools.rag.rag_tool.Retriever") as mock_retriever_class, + patch("apps.ai.agent.tools.rag.rag_tool.Generator") as mock_generator_class, + ): + mock_retriever = MagicMock() + mock_generator = MagicMock() + mock_retriever_class.return_value = mock_retriever + mock_generator_class.return_value = mock_generator + + mock_retriever.retrieve.return_value = [] + mock_generator.generate_answer.return_value = "Answer" + + rag_tool = RagTool() + rag_tool.query("Test question") + + mock_logger.info.assert_called_once_with("Retrieving context for query") diff --git a/backend/tests/apps/ai/agent/tools/rag/retriever_test.py b/backend/tests/apps/ai/agent/tools/rag/retriever_test.py new file mode 100644 index 0000000000..c27115195c --- /dev/null +++ b/backend/tests/apps/ai/agent/tools/rag/retriever_test.py @@ -0,0 +1,595 @@ +"""Tests for the RAG Retriever.""" + +import os +from unittest.mock import MagicMock, patch + +import openai +import pytest + +from apps.ai.agent.tools.rag.retriever import Retriever + + +class TestRetriever: + """Test cases for the Retriever class.""" + + def test_init_success(self): + """Test successful initialization with API key.""" + with ( + patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}), + patch("openai.OpenAI") as mock_openai, + ): + mock_client = MagicMock() + mock_openai.return_value = mock_client + + retriever = Retriever(embedding_model="custom-model") + + assert retriever.embedding_model == "custom-model" + assert retriever.openai_client == mock_client + mock_openai.assert_called_once_with(api_key="test-key") + + def test_init_no_api_key(self): + """Test initialization fails when API key is not set.""" + with ( + patch.dict(os.environ, {}, clear=True), + pytest.raises( + ValueError, + match="DJANGO_OPEN_AI_SECRET_KEY environment variable not set", + ), + ): + Retriever() + + def test_init_default_model(self): + """Test initialization with default embedding model.""" + with ( + 
patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}), + patch("openai.OpenAI"), + ): + retriever = Retriever() + assert retriever.embedding_model == "text-embedding-3-small" + + def test_get_query_embedding_success(self): + """Test successful query embedding generation.""" + with ( + patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}), + patch("openai.OpenAI") as mock_openai, + ): + mock_client = MagicMock() + mock_response = MagicMock() + mock_response.data = [MagicMock(embedding=[0.1, 0.2, 0.3])] + mock_client.embeddings.create.return_value = mock_response + mock_openai.return_value = mock_client + + retriever = Retriever() + result = retriever.get_query_embedding("test query") + + assert result == [0.1, 0.2, 0.3] + mock_client.embeddings.create.assert_called_once_with( + input=["test query"], model="text-embedding-3-small" + ) + + def test_get_query_embedding_openai_error(self): + """Test query embedding with OpenAI API error.""" + with ( + patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}), + patch("openai.OpenAI") as mock_openai, + ): + mock_client = MagicMock() + mock_client.embeddings.create.side_effect = openai.OpenAIError("API Error") + mock_openai.return_value = mock_client + + retriever = Retriever() + + with pytest.raises(openai.OpenAIError): + retriever.get_query_embedding("test query") + + def test_get_query_embedding_unexpected_error(self): + """Test query embedding with unexpected error.""" + with ( + patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}), + patch("openai.OpenAI") as mock_openai, + ): + mock_client = MagicMock() + mock_client.embeddings.create.side_effect = Exception("Unexpected error") + mock_openai.return_value = mock_client + + retriever = Retriever() + + with pytest.raises(Exception, match="Unexpected error"): + retriever.get_query_embedding("test query") + + def test_get_source_name_with_name(self): + """Test getting source name when object has name attribute.""" + with ( + 
patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}), + patch("openai.OpenAI"), + ): + retriever = Retriever() + + entity = MagicMock() + entity.name = "Test Name" + entity.title = "Test Title" + + result = retriever.get_source_name(entity) + assert result == "Test Name" + + def test_get_source_name_with_title(self): + """Test getting source name when object has title but no name.""" + with ( + patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}), + patch("openai.OpenAI"), + ): + retriever = Retriever() + + entity = MagicMock() + entity.name = None + entity.title = "Test Title" + entity.login = "test_login" + + result = retriever.get_source_name(entity) + assert result == "Test Title" + + def test_get_source_name_with_login(self): + """Test getting source name when object has login but no name/title.""" + with ( + patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}), + patch("openai.OpenAI"), + ): + retriever = Retriever() + + entity = MagicMock() + entity.name = None + entity.title = None + entity.login = "test_login" + entity.key = "test_key" + + result = retriever.get_source_name(entity) + assert result == "test_login" + + def test_get_source_name_fallback_to_str(self): + """Test getting source name falls back to string representation.""" + with ( + patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}), + patch("openai.OpenAI"), + ): + retriever = Retriever() + + entity = MagicMock() + entity.name = None + entity.title = None + entity.login = None + entity.key = None + entity.summary = None + entity.__str__ = MagicMock(return_value="String representation") + + result = retriever.get_source_name(entity) + assert result == "String representation" + + def test_get_additional_context_chapter(self): + """Test getting additional context for chapter content type.""" + with ( + patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}), + patch("openai.OpenAI"), + ): + retriever = Retriever() + + content_object = 
MagicMock() + content_object.__class__.__name__ = "Chapter" + content_object.suggested_location = "New York" + content_object.region = "North America" + content_object.country = "USA" + content_object.postal_code = "10001" + content_object.currency = "USD" + content_object.meetup_group = "OWASP NYC" + content_object.tags = ["security", "web"] + content_object.topics = ["OWASP Top 10"] + content_object.leaders_raw = ["John Doe", "Jane Smith"] + content_object.related_urls = ["https://example.com"] + content_object.is_active = True + content_object.url = "https://owasp.org/chapter" + + result = retriever.get_additional_context(content_object) + + expected_keys = [ + "location", + "region", + "country", + "postal_code", + "currency", + "meetup_group", + "tags", + "topics", + "leaders", + "related_urls", + "is_active", + "url", + ] + for key in expected_keys: + assert key in result + + def test_get_additional_context_project(self): + """Test getting additional context for project content type.""" + with ( + patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}), + patch("openai.OpenAI"), + ): + retriever = Retriever() + + content_object = MagicMock() + content_object.__class__.__name__ = "Project" + content_object.level = "flagship" + content_object.type = "tool" + content_object.languages = ["python"] + content_object.topics = ["security"] + content_object.licenses = ["MIT"] + content_object.tags = ["web"] + content_object.custom_tags = ["api"] + content_object.stars_count = 100 + content_object.forks_count = 20 + content_object.contributors_count = 5 + content_object.releases_count = 3 + content_object.open_issues_count = 2 + content_object.leaders_raw = ["Alice"] + content_object.related_urls = ["https://project.example.com"] + content_object.created_at = "2023-01-01" + content_object.updated_at = "2023-01-02" + content_object.released_at = "2023-01-03" + content_object.health_score = 85.5 + content_object.is_active = True + content_object.track_issues = 
True + content_object.url = "https://owasp.org/project" + + result = retriever.get_additional_context(content_object) + + expected_keys = [ + "level", + "project_type", + "languages", + "topics", + "licenses", + "tags", + "custom_tags", + "stars_count", + "forks_count", + "contributors_count", + "releases_count", + "open_issues_count", + "leaders", + "related_urls", + "created_at", + "updated_at", + "released_at", + "health_score", + "is_active", + "track_issues", + "url", + ] + for key in expected_keys: + assert key in result + + def test_get_additional_context_event(self): + """Test getting additional context for event content type.""" + with ( + patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}), + patch("openai.OpenAI"), + ): + retriever = Retriever() + + content_object = MagicMock() + content_object.__class__.__name__ = "Event" + content_object.start_date = "2023-01-01" + content_object.end_date = "2023-01-02" + content_object.suggested_location = "San Francisco" + content_object.category = "conference" + content_object.latitude = 37.7749 + content_object.longitude = -122.4194 + content_object.url = "https://event.example.com" + content_object.description = "Test event description" + content_object.summary = "Test event summary" + + result = retriever.get_additional_context(content_object) + + expected_keys = [ + "start_date", + "end_date", + "location", + "category", + "latitude", + "longitude", + "url", + "description", + "summary", + ] + for key in expected_keys: + assert key in result + + def test_get_additional_context_committee(self): + """Test getting additional context for committee content type.""" + with ( + patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}), + patch("openai.OpenAI"), + ): + retriever = Retriever() + + content_object = MagicMock() + content_object.__class__.__name__ = "Committee" + content_object.is_active = True + content_object.leaders = ["John Doe", "Jane Smith"] + content_object.url = 
"https://committee.example.com" + content_object.description = "Test committee description" + content_object.summary = "Test committee summary" + content_object.tags = ["security", "governance"] + content_object.topics = ["policy", "standards"] + content_object.related_urls = ["https://related.example.com"] + + result = retriever.get_additional_context(content_object) + + expected_keys = [ + "is_active", + "leaders", + "url", + "description", + "summary", + "tags", + "topics", + "related_urls", + ] + for key in expected_keys: + assert key in result + + def test_get_additional_context_message(self): + """Test getting additional context for message content type.""" + with ( + patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}), + patch("openai.OpenAI"), + ): + retriever = Retriever() + + conversation = MagicMock() + conversation.slack_channel_id = "C1234567890" + + parent_message = MagicMock() + parent_message.ts = "1234567890.123456" + + author = MagicMock() + author.name = "testuser" + + content_object = MagicMock() + content_object.__class__.__name__ = "Message" + content_object.conversation = conversation + content_object.parent_message = parent_message + content_object.ts = "1234567891.123456" + content_object.author = author + + result = retriever.get_additional_context(content_object) + + expected_keys = ["channel", "thread_ts", "ts", "user"] + for key in expected_keys: + assert key in result + + assert result["channel"] == "C1234567890" + assert result["thread_ts"] == "1234567890.123456" + assert result["ts"] == "1234567891.123456" + assert result["user"] == "testuser" + + def test_get_additional_context_message_no_conversation(self): + """Test getting additional context for message with no conversation.""" + with ( + patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}), + patch("openai.OpenAI"), + ): + retriever = Retriever() + + content_object = MagicMock() + content_object.__class__.__name__ = "Message" + 
content_object.conversation = None + content_object.parent_message = None + content_object.ts = "1234567891.123456" + content_object.author = None + + result = retriever.get_additional_context(content_object) + + assert result["ts"] == "1234567891.123456" + + def test_extract_content_types_from_query_single_type(self): + """Test extracting single content type from query.""" + with ( + patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}), + patch("openai.OpenAI"), + ): + retriever = Retriever() + + result = retriever.extract_content_types_from_query("Tell me about chapters") + assert result == ["chapter"] + + def test_extract_content_types_from_query_multiple_types(self): + """Test extracting multiple content types from query.""" + with ( + patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}), + patch("openai.OpenAI"), + ): + retriever = Retriever() + + result = retriever.extract_content_types_from_query("Show me events and projects") + assert set(result) == {"event", "project"} + + def test_extract_content_types_from_query_plural_forms(self): + """Test extracting content types from query with plural forms.""" + with ( + patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}), + patch("openai.OpenAI"), + ): + retriever = Retriever() + + result = retriever.extract_content_types_from_query("List all committees") + assert result == ["committee"] + + def test_extract_content_types_from_query_no_matches(self): + """Test extracting content types when no matches found.""" + with ( + patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}), + patch("openai.OpenAI"), + ): + retriever = Retriever() + + result = retriever.extract_content_types_from_query("Random query with no keywords") + assert result == [] + + def test_supported_content_types(self): + """Test that supported content types are defined correctly.""" + assert set(Retriever.SUPPORTED_ENTITY_TYPES) == { + "chapter", + "committee", + "event", + "message", + "project", + 
} + + @patch("apps.ai.agent.tools.rag.retriever.Chunk") + def test_retrieve_with_app_label_content_types(self, mock_chunk): + """Test retrieve method with app_label.model content types filter.""" + with ( + patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}), + patch("openai.OpenAI") as mock_openai, + ): + mock_client = MagicMock() + mock_response = MagicMock() + mock_response.data = [MagicMock(embedding=[0.1, 0.2, 0.3])] + mock_client.embeddings.create.return_value = mock_response + mock_openai.return_value = mock_client + + mock_annotated = MagicMock() + mock_filtered = MagicMock() + mock_final = MagicMock() + + mock_chunk.objects.annotate.return_value = mock_annotated + mock_annotated.filter.return_value = mock_filtered + mock_filtered.filter.return_value = mock_final + mock_final.select_related.return_value = mock_final + mock_final.order_by.return_value = mock_final + mock_final.__getitem__ = MagicMock(return_value=[]) + + retriever = Retriever() + result = retriever.retrieve("test query", content_types=["owasp.chapter"]) + + assert result == [] + mock_chunk.objects.annotate.assert_called_once() + mock_annotated.filter.assert_called_once() + mock_filtered.filter.assert_called_once() + + @patch("apps.ai.agent.tools.rag.retriever.Chunk") + def test_retrieve_successful_with_chunks(self, mock_chunk): + """Test retrieve method with successful chunk retrieval.""" + with ( + patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}), + patch("openai.OpenAI") as mock_openai, + ): + mock_client = MagicMock() + mock_response = MagicMock() + mock_response.data = [MagicMock(embedding=[0.1, 0.2, 0.3])] + mock_client.embeddings.create.return_value = mock_response + mock_openai.return_value = mock_client + + mock_content_object = MagicMock() + mock_content_object.name = "Test Chapter" + mock_content_object.__class__.__name__ = "Chapter" + mock_content_object.suggested_location = "New York" + + mock_entity_type = MagicMock() + mock_entity_type.model = 
"chapter" + + mock_context = MagicMock() + mock_context.entity = mock_content_object + mock_context.entity_type = mock_entity_type + mock_context.entity_id = "123" + + mock_chunk_instance = MagicMock() + mock_chunk_instance.id = 1 + mock_chunk_instance.text = "Test chunk text" + mock_chunk_instance.similarity = 0.85 + mock_chunk_instance.context = mock_context + + mock_annotated = MagicMock() + mock_filtered = MagicMock() + + mock_chunk.objects.annotate.return_value = mock_annotated + mock_annotated.filter.return_value = mock_filtered + mock_filtered.select_related.return_value = mock_filtered + mock_filtered.order_by.return_value = mock_filtered + mock_filtered.__getitem__ = MagicMock(return_value=[mock_chunk_instance]) + + retriever = Retriever() + result = retriever.retrieve("test query") + + assert len(result) == 1 + assert result[0]["text"] == "Test chunk text" + assert result[0]["source_type"] == "chapter" + assert result[0]["source_name"] == "Test Chapter" + assert result[0]["source_id"] == "123" + assert "additional_context" in result[0] + + @patch("apps.ai.agent.tools.rag.retriever.Chunk") + def test_retrieve_with_content_types_filter(self, mock_chunk): + """Test retrieve method with content types filter.""" + with ( + patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}), + patch("openai.OpenAI") as mock_openai, + ): + mock_client = MagicMock() + mock_response = MagicMock() + mock_response.data = [MagicMock(embedding=[0.1, 0.2, 0.3])] + mock_client.embeddings.create.return_value = mock_response + mock_openai.return_value = mock_client + + mock_annotated = MagicMock() + mock_filtered = MagicMock() + mock_final = MagicMock() + + mock_chunk.objects.annotate.return_value = mock_annotated + mock_annotated.filter.return_value = mock_filtered + mock_filtered.filter.return_value = mock_final + mock_final.select_related.return_value = mock_final + mock_final.order_by.return_value = mock_final + mock_final.__getitem__ = MagicMock(return_value=[]) + + 
retriever = Retriever() + result = retriever.retrieve("test query", content_types=["chapter"]) + + assert result == [] + mock_chunk.objects.annotate.assert_called_once() + mock_annotated.filter.assert_called_once() + mock_filtered.filter.assert_called_once() + + @patch("apps.ai.agent.tools.rag.retriever.logger") + @patch("apps.ai.agent.tools.rag.retriever.Chunk") + def test_retrieve_with_none_content_object(self, mock_chunk, mock_logger): + """Test retrieve method when content object is None.""" + with ( + patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}), + patch("openai.OpenAI") as mock_openai, + ): + mock_client = MagicMock() + mock_response = MagicMock() + mock_response.data = [MagicMock(embedding=[0.1, 0.2, 0.3])] + mock_client.embeddings.create.return_value = mock_response + mock_openai.return_value = mock_client + + mock_chunk_instance = MagicMock() + mock_chunk_instance.id = 1 + mock_chunk_instance.context = None + + mock_annotated = MagicMock() + mock_filtered = MagicMock() + + mock_chunk.objects.annotate.return_value = mock_annotated + mock_annotated.filter.return_value = mock_filtered + mock_filtered.select_related.return_value = mock_filtered + mock_filtered.order_by.return_value = mock_filtered + mock_filtered.__getitem__ = MagicMock(return_value=[mock_chunk_instance]) + + retriever = Retriever() + result = retriever.retrieve("test query") + + assert result == [] + mock_logger.warning.assert_called_once_with( + "Content object is None for chunk %s. 
Skipping.", 1 + ) diff --git a/backend/tests/apps/ai/common/base/__init__.py b/backend/tests/apps/ai/common/base/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/backend/tests/apps/ai/common/base/ai_command_test.py b/backend/tests/apps/ai/common/base/ai_command_test.py new file mode 100644 index 0000000000..83d73dfc97 --- /dev/null +++ b/backend/tests/apps/ai/common/base/ai_command_test.py @@ -0,0 +1,269 @@ +import os +from unittest.mock import Mock, patch + +import pytest +from django.core.management.base import BaseCommand + +from apps.ai.common.base.ai_command import BaseAICommand + + +class MockTestModel: + """Mock model for testing.""" + + objects = Mock() + pk = 1 + __name__ = "TestEntity" + + +@pytest.fixture +def command(): + """Fixture for ConcreteAICommand instance.""" + cmd = ConcreteAICommand() + cmd.entity_name = "test_entity" + cmd.entity_name_plural = "test_entities" + return cmd + + +@pytest.fixture +def mock_entity(): + """Fixture for mock entity with test_key attribute.""" + entity = Mock() + entity.test_key = "test-key-123" + entity.pk = 42 + return entity + + +@pytest.fixture +def mock_queryset(): + """Fixture for mock queryset.""" + return Mock() + + +class ConcreteAICommand(BaseAICommand): + """Concrete implementation of BaseAICommand for testing.""" + + model_class = MockTestModel + entity_name = "test_entity" + entity_name_plural = "test_entities" + key_field_name = "test_key" + + def extract_content(self, entity): + return ("prose content", "metadata content") + + +class TestBaseAICommand: + """Test suite for the BaseAICommand class.""" + + def test_command_inheritance(self, command): + assert isinstance(command, BaseCommand) + + def test_initialization(self, command): + assert command.openai_client is None + + def test_abstract_attributes_implemented(self, command): + assert command.model_class == MockTestModel + assert command.entity_name == "test_entity" + assert command.entity_name_plural == "test_entities" + 
assert command.key_field_name == "test_key" + + mock_entity = Mock() + result = command.extract_content(mock_entity) + assert result == ("prose content", "metadata content") + + def test_source_name_default(self, command): + result = command.source_name() + assert result == "owasp_test_entity" + + def test_get_base_queryset(self, command): + with patch.object(MockTestModel, "objects") as mock_objects: + mock_manager = Mock() + mock_objects.all.return_value = mock_manager + mock_objects.return_value = mock_manager + + command.get_base_queryset() + mock_objects.all.assert_called_once() + + def test_get_default_queryset(self, command): + with patch.object(command, "get_base_queryset") as mock_base_qs: + mock_queryset = Mock() + mock_filtered_qs = Mock() + mock_queryset.filter.return_value = mock_filtered_qs + mock_base_qs.return_value = mock_queryset + + result = command.get_default_queryset() + + mock_base_qs.assert_called_once() + mock_queryset.filter.assert_called_once_with(is_active=True) + assert result == mock_filtered_qs + + def test_add_common_arguments(self, command): + parser = Mock() + + command.add_common_arguments(parser) + + assert parser.add_argument.call_count == 3 + + calls = parser.add_argument.call_args_list + + assert calls[0][0] == ("--test_entity-key",) + assert calls[0][1]["type"] is str + assert "Process only the test_entity with this key" in calls[0][1]["help"] + + assert calls[1][0] == ("--all",) + assert calls[1][1]["action"] == "store_true" + assert "Process all the test_entities" in calls[1][1]["help"] + + assert calls[2][0] == ("--batch-size",) + assert calls[2][1]["type"] is int + assert calls[2][1]["default"] == 50 + assert "Number of test_entities to process in each batch" in calls[2][1]["help"] + + def test_add_arguments_calls_common(self, command): + parser = Mock() + + with patch.object(command, "add_common_arguments") as mock_add_common: + command.add_arguments(parser) + mock_add_common.assert_called_once_with(parser) + + def 
test_get_queryset_with_entity_key(self, command): + options = {"test_entity_key": "test-key-123"} + + with patch.object(command, "get_base_queryset") as mock_base_qs: + mock_queryset = Mock() + mock_filtered_qs = Mock() + mock_queryset.filter.return_value = mock_filtered_qs + mock_base_qs.return_value = mock_queryset + + result = command.get_queryset(options) + + mock_base_qs.assert_called_once() + mock_queryset.filter.assert_called_once_with(test_key="test-key-123") + assert result == mock_filtered_qs + + def test_get_queryset_with_all_option(self, command): + options = {"all": True} + + with patch.object(command, "get_base_queryset") as mock_base_qs: + mock_queryset = Mock() + mock_base_qs.return_value = mock_queryset + + result = command.get_queryset(options) + + mock_base_qs.assert_called_once() + assert result == mock_queryset + + def test_get_queryset_default(self, command): + options = {} + + with patch.object(command, "get_default_queryset") as mock_default_qs: + mock_queryset = Mock() + mock_default_qs.return_value = mock_queryset + + result = command.get_queryset(options) + + mock_default_qs.assert_called_once() + assert result == mock_queryset + + def test_get_entity_key_with_key_field(self, command, mock_entity): + result = command.get_entity_key(mock_entity) + assert result == "test-key-123" + + def test_get_entity_key_fallback_to_pk(self, command): + entity = Mock() + entity.pk = 42 + if hasattr(entity, "test_key"): + delattr(entity, "test_key") + + result = command.get_entity_key(entity) + assert result == "42" + + @patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-api-key"}) + @patch("apps.ai.common.base.ai_command.openai.OpenAI") + def test_setup_openai_client_success(self, mock_openai_class, command): + mock_client = Mock() + mock_openai_class.return_value = mock_client + + result = command.setup_openai_client() + + assert result is True + assert command.openai_client == mock_client + 
mock_openai_class.assert_called_once_with(api_key="test-api-key") + + @patch.dict(os.environ, {}, clear=True) + def test_setup_openai_client_no_api_key(self, command): + if "DJANGO_OPEN_AI_SECRET_KEY" in os.environ: + del os.environ["DJANGO_OPEN_AI_SECRET_KEY"] + + with patch.object(command.stdout, "write") as mock_write: + result = command.setup_openai_client() + + assert result is False + assert command.openai_client is None + mock_write.assert_called_once() + call_args = mock_write.call_args[0][0] + assert "DJANGO_OPEN_AI_SECRET_KEY environment variable not set" in str(call_args) + + def test_handle_batch_processing_empty_queryset(self, command, mock_queryset): + mock_queryset.count.return_value = 0 + process_batch_func = Mock() + + with patch.object(command.stdout, "write") as mock_write: + command.handle_batch_processing(mock_queryset, 10, process_batch_func) + + mock_write.assert_called_once_with("No test_entities found to process") + process_batch_func.assert_not_called() + + def test_handle_batch_processing_with_data(self, command): + mock_entities = [Mock() for _ in range(15)] + + mock_queryset = Mock() + mock_queryset.count.return_value = 15 + + def mock_getitem(slice_obj): + start = slice_obj.start or 0 + stop = slice_obj.stop + return mock_entities[start:stop] + + mock_queryset.__getitem__ = Mock(side_effect=mock_getitem) + + process_batch_func = Mock(side_effect=[5, 5, 5]) + + with patch.object(command.stdout, "write") as mock_write: + command.handle_batch_processing(mock_queryset, 5, process_batch_func) + + assert process_batch_func.call_count == 3 + + calls = process_batch_func.call_args_list + assert len(calls[0][0][0]) == 5 + assert len(calls[1][0][0]) == 5 + assert len(calls[2][0][0]) == 5 + + write_calls = mock_write.call_args_list + assert len(write_calls) == 2 + assert "Found 15 test_entities to process" in str(write_calls[0]) + assert "Completed processing 15/15 test_entities" in str(write_calls[1]) + + def 
test_handle_batch_processing_partial_processing(self, command): + mock_entities = [Mock() for _ in range(10)] + + mock_queryset = Mock() + mock_queryset.count.return_value = 10 + + def mock_getitem(slice_obj): + start = slice_obj.start or 0 + stop = slice_obj.stop + return mock_entities[start:stop] + + mock_queryset.__getitem__ = Mock(side_effect=mock_getitem) + + process_batch_func = Mock(side_effect=[3, 2]) + + with patch.object(command.stdout, "write") as mock_write: + command.handle_batch_processing(mock_queryset, 5, process_batch_func) + + assert process_batch_func.call_count == 2 + + write_calls = mock_write.call_args_list + assert len(write_calls) == 2 + assert "Found 10 test_entities to process" in str(write_calls[0]) + assert "Completed processing 5/10 test_entities" in str(write_calls[1]) diff --git a/backend/tests/apps/ai/common/base/chunk_command_test.py b/backend/tests/apps/ai/common/base/chunk_command_test.py new file mode 100644 index 0000000000..627acf36f0 --- /dev/null +++ b/backend/tests/apps/ai/common/base/chunk_command_test.py @@ -0,0 +1,436 @@ +"""Tests for the BaseChunkCommand class.""" + +from typing import Any +from unittest.mock import Mock, call, patch + +import pytest +from django.contrib.contenttypes.models import ContentType +from django.core.management.base import BaseCommand + +from apps.ai.common.base.chunk_command import BaseChunkCommand +from apps.ai.models.chunk import Chunk +from apps.ai.models.context import Context + + +class ConcreteChunkCommand(BaseChunkCommand): + """Concrete implementation of BaseChunkCommand for testing.""" + + model_class: type[Any] = Mock # type: ignore[assignment] + entity_name = "test_entity" + entity_name_plural = "test_entities" + key_field_name = "test_key" + + def extract_content(self, entity): + """Extract content from entity.""" + return ("prose content", "metadata content") + + +@pytest.fixture +def command(): + """Return a concrete chunk command instance for testing.""" + cmd = 
ConcreteChunkCommand() + mock_model = Mock() + mock_model.__name__ = "TestEntity" + cmd.model_class = mock_model + cmd.entity_name = "test_entity" + cmd.entity_name_plural = "test_entities" + return cmd + + +@pytest.fixture +def mock_entity(): + """Return a mock entity instance.""" + entity = Mock() + entity.id = 1 + entity.test_key = "test-key-123" + entity.is_active = True + return entity + + +@pytest.fixture +def mock_context(): + """Return a mock context instance.""" + context = Mock(spec=Context) + context.id = 1 + context.content_type_id = 1 + context.object_id = 1 + context.chunks.aggregate.return_value = {"latest_created": None} + context.chunks.all.return_value.delete.return_value = (0, {}) + context.nest_updated_at = Mock() + return context + + +@pytest.fixture +def mock_content_type(): + """Return a mock content type.""" + content_type = Mock(spec=ContentType) + content_type.id = 1 + return content_type + + +@pytest.fixture +def mock_chunks(): + """Return a list of mock chunk instances.""" + chunks = [] + for i in range(3): + chunk = Mock(spec=Chunk) + chunk.id = i + 1 + chunk.text = f"Chunk text {i + 1}" + chunk.context_id = 1 + chunks.append(chunk) + return chunks + + +class TestBaseChunkCommand: + """Test suite for the BaseChunkCommand class.""" + + def test_command_inheritance(self, command): + """Test that BaseChunkCommand inherits from BaseCommand.""" + assert isinstance(command, BaseCommand) + + def test_help_method(self, command): + """Test the help method returns appropriate help text.""" + expected_help = "Create or update chunks for OWASP test_entity data" + assert command.help() == expected_help + + def test_abstract_methods_implemented(self, command): + """Test that all abstract methods are properly implemented.""" + assert command.model_class.__name__ == "TestEntity" + assert command.entity_name == "test_entity" + assert command.entity_name_plural == "test_entities" + assert command.key_field_name == "test_key" + + mock_entity = Mock() + 
result = command.extract_content(mock_entity) + assert result == ("prose content", "metadata content") + + @patch("apps.ai.common.base.chunk_command.ContentType.objects.get_for_model") + @patch("apps.ai.common.base.chunk_command.Context.objects.filter") + def test_process_chunks_batch_no_context( + self, + mock_context_filter, + mock_get_content_type, + command, + mock_entity, + mock_content_type, + ): + """Test process_chunks_batch when no context is found.""" + mock_get_content_type.return_value = mock_content_type + mock_context_filter.return_value.first.return_value = None + + with patch.object(command.stdout, "write") as mock_write: + result = command.process_chunks_batch([mock_entity]) + + assert result == 0 + mock_write.assert_called_once() + warning_call = mock_write.call_args[0][0] + assert "No context found for test_entity test-key-123" in str(warning_call) + + @patch("apps.ai.common.base.chunk_command.ContentType.objects.get_for_model") + @patch("apps.ai.common.base.chunk_command.Context.objects.filter") + def test_process_chunks_batch_empty_content( + self, + mock_context_filter, + mock_get_content_type, + command, + mock_entity, + mock_context, + mock_content_type, + ): + """Test process_chunks_batch when extracted content is empty.""" + mock_get_content_type.return_value = mock_content_type + mock_context_filter.return_value.first.return_value = mock_context + + with ( + patch.object(command, "extract_content", return_value=("", "")), + patch.object(command.stdout, "write") as mock_write, + ): + result = command.process_chunks_batch([mock_entity]) + + assert result == 0 + # Check that it wrote the initial message and the empty content message + expected_calls = [ + call("Context for test-key-123 requires chunk creation/update"), + call("No content to chunk for test_entity test-key-123"), + ] + mock_write.assert_has_calls(expected_calls) + + @patch("apps.ai.common.base.chunk_command.ContentType.objects.get_for_model") + 
@patch("apps.ai.common.base.chunk_command.Context.objects.filter") + @patch("apps.ai.models.chunk.Chunk.split_text") + def test_process_chunks_batch_no_chunks_created( + self, + mock_split_text, + mock_context_filter, + mock_get_content_type, + command, + mock_entity, + mock_context, + mock_content_type, + ): + """Test process_chunks_batch when no chunks are created from text.""" + mock_get_content_type.return_value = mock_content_type + mock_context_filter.return_value.first.return_value = mock_context + mock_split_text.return_value = [] + + with patch.object(command.stdout, "write") as mock_write: + result = command.process_chunks_batch([mock_entity]) + + assert result == 0 + # Check that both messages were written + expected_calls = [ + call("Context for test-key-123 requires chunk creation/update"), + call("No chunks created for test_entity test-key-123"), + ] + mock_write.assert_has_calls(expected_calls) + call_args = mock_write.call_args[0][0] + assert "No chunks created for test_entity test-key-123" in call_args + + @patch("apps.ai.common.base.chunk_command.ContentType.objects.get_for_model") + @patch("apps.ai.common.base.chunk_command.Context.objects.filter") + @patch("apps.ai.models.chunk.Chunk.split_text") + @patch("apps.ai.common.base.chunk_command.create_chunks_and_embeddings") + @patch("apps.ai.models.chunk.Chunk.bulk_save") + def test_process_chunks_batch_success( + self, + mock_bulk_save, + mock_create_chunks, + mock_split_text, + mock_context_filter, + mock_get_content_type, + command, + mock_entity, + mock_context, + mock_content_type, + mock_chunks, + ): + """Test successful chunk processing.""" + mock_get_content_type.return_value = mock_content_type + mock_context_filter.return_value.first.return_value = mock_context + mock_split_text.return_value = ["chunk1", "chunk2", "chunk3"] + mock_create_chunks.return_value = mock_chunks + command.openai_client = Mock() + + with patch.object(command.stdout, "write") as mock_write: + result = 
command.process_chunks_batch([mock_entity]) + + assert result == 1 + mock_create_chunks.assert_called_once_with( + chunk_texts=["chunk1", "chunk2", "chunk3"], + context=mock_context, + openai_client=command.openai_client, + save=False, + ) + mock_bulk_save.assert_called_once_with(mock_chunks) + mock_write.assert_has_calls( + [ + call("Context for test-key-123 requires chunk creation/update"), + call(command.style.SUCCESS("Created 3 new chunks for test-key-123")), + ] + ) + + @patch("apps.ai.common.base.chunk_command.ContentType.objects.get_for_model") + @patch("apps.ai.common.base.chunk_command.Context.objects.filter") + @patch("apps.ai.models.chunk.Chunk.split_text") + @patch("apps.ai.common.base.chunk_command.create_chunks_and_embeddings") + @patch("apps.ai.models.chunk.Chunk.bulk_save") + def test_process_chunks_batch_multiple_entities( + self, + mock_bulk_save, + mock_create_chunks, + mock_split_text, + mock_context_filter, + mock_get_content_type, + command, + mock_context, + mock_content_type, + mock_chunks, + ): + """Test processing multiple entities in a batch.""" + entities = [] + for i in range(3): + entity = Mock() + entity.id = i + 1 + entity.test_key = f"test-key-{i + 1}" + entity.is_active = True + entities.append(entity) + + mock_get_content_type.return_value = mock_content_type + mock_context_filter.return_value.first.return_value = mock_context + mock_split_text.return_value = ["chunk1", "chunk2"] + mock_create_chunks.return_value = mock_chunks[:2] + command.openai_client = Mock() + + with patch.object(command.stdout, "write"): + result = command.process_chunks_batch(entities) + + assert result == 3 + assert mock_create_chunks.call_count == 3 + mock_bulk_save.assert_called_once() + bulk_save_args = mock_bulk_save.call_args[0][0] + assert len(bulk_save_args) == 6 + + @patch("apps.ai.common.base.chunk_command.ContentType.objects.get_for_model") + @patch("apps.ai.common.base.chunk_command.Context.objects.filter") + 
@patch("apps.ai.models.chunk.Chunk.split_text") + @patch("apps.ai.common.base.chunk_command.create_chunks_and_embeddings") + def test_process_chunks_batch_create_chunks_fails( + self, + mock_create_chunks, + mock_split_text, + mock_context_filter, + mock_get_content_type, + command, + mock_entity, + mock_context, + mock_content_type, + ): + """Test process_chunks_batch when create_chunks_and_embeddings fails.""" + mock_get_content_type.return_value = mock_content_type + mock_context_filter.return_value.first.return_value = mock_context + mock_split_text.return_value = ["chunk1", "chunk2"] + mock_create_chunks.return_value = None + command.openai_client = Mock() + + result = command.process_chunks_batch([mock_entity]) + + assert result == 0 + mock_create_chunks.assert_called_once() + + def test_process_chunks_batch_content_combination( + self, command, mock_entity, mock_context, mock_content_type + ): + """Test that metadata and prose content are properly combined.""" + with ( + patch( + "apps.ai.common.base.chunk_command.ContentType.objects.get_for_model" + ) as mock_get_content_type, + patch( + "apps.ai.common.base.chunk_command.Context.objects.filter" + ) as mock_context_filter, + patch("apps.ai.models.chunk.Chunk.split_text") as mock_split_text, + patch( + "apps.ai.common.base.chunk_command.create_chunks_and_embeddings" + ) as mock_create_chunks, + patch("apps.ai.models.chunk.Chunk.bulk_save"), + ): + mock_get_content_type.return_value = mock_content_type + mock_context_filter.return_value.first.return_value = mock_context + mock_split_text.return_value = ["chunk1"] + mock_create_chunks.return_value = [Mock()] + command.openai_client = Mock() + + with patch.object( + command, + "extract_content", + return_value=("prose", "metadata"), + ): + command.process_chunks_batch([mock_entity]) + + expected_content = "metadata\n\nprose" + mock_split_text.assert_called_once_with(expected_content) + + mock_split_text.reset_mock() + with patch.object(command, 
"extract_content", return_value=("prose", "")): + command.process_chunks_batch([mock_entity]) + + mock_split_text.assert_called_with("prose") + + @patch.object(BaseChunkCommand, "setup_openai_client") + @patch.object(BaseChunkCommand, "get_queryset") + @patch.object(BaseChunkCommand, "handle_batch_processing") + def test_handle_method_success( + self, mock_handle_batch, mock_get_queryset, mock_setup_client, command + ): + """Test the handle method with successful setup.""" + mock_setup_client.return_value = True + mock_queryset = Mock() + mock_get_queryset.return_value = mock_queryset + options = {"batch_size": 10} + + command.handle(**options) + + mock_setup_client.assert_called_once() + mock_get_queryset.assert_called_once_with(options) + mock_handle_batch.assert_called_once_with( + queryset=mock_queryset, + batch_size=10, + process_batch_func=command.process_chunks_batch, + ) + + @patch.object(BaseChunkCommand, "setup_openai_client") + def test_handle_method_openai_setup_fails(self, mock_setup_client, command): + """Test the handle method when OpenAI client setup fails.""" + mock_setup_client.return_value = False + options = {"batch_size": 10} + + with ( + patch.object(command, "get_queryset") as mock_get_queryset, + patch.object(command, "handle_batch_processing") as mock_handle_batch, + ): + command.handle(**options) + + mock_setup_client.assert_called_once() + mock_get_queryset.assert_not_called() + mock_handle_batch.assert_not_called() + + def test_process_chunks_batch_metadata_only_content( + self, command, mock_entity, mock_context, mock_content_type + ): + """Test process_chunks_batch with only metadata content.""" + with ( + patch( + "apps.ai.common.base.chunk_command.ContentType.objects.get_for_model" + ) as mock_get_content_type, + patch( + "apps.ai.common.base.chunk_command.Context.objects.filter" + ) as mock_context_filter, + patch("apps.ai.models.chunk.Chunk.split_text") as mock_split_text, + patch( + 
"apps.ai.common.base.chunk_command.create_chunks_and_embeddings" + ) as mock_create_chunks, + patch("apps.ai.models.chunk.Chunk.bulk_save") as mock_bulk_save, + ): + mock_get_content_type.return_value = mock_content_type + mock_context_filter.return_value.first.return_value = mock_context + mock_split_text.return_value = ["chunk1"] + mock_create_chunks.return_value = [Mock()] + command.openai_client = Mock() + + with patch.object( + command, + "extract_content", + return_value=("", "metadata"), + ): + command.process_chunks_batch([mock_entity]) + + mock_split_text.assert_called_once_with("metadata\n\n") + mock_bulk_save.assert_called_once() + + def test_process_chunks_batch_whitespace_only_content( + self, command, mock_entity, mock_context, mock_content_type + ): + """Test process_chunks_batch with whitespace-only content.""" + with ( + patch( + "apps.ai.common.base.chunk_command.ContentType.objects.get_for_model" + ) as mock_get_content_type, + patch( + "apps.ai.common.base.chunk_command.Context.objects.filter" + ) as mock_context_filter, + ): + mock_get_content_type.return_value = mock_content_type + mock_context_filter.return_value.first.return_value = mock_context + + with ( + patch.object(command, "extract_content", return_value=(" \n\t ", " \t\n ")), + patch.object(command.stdout, "write") as mock_write, + ): + result = command.process_chunks_batch([mock_entity]) + + assert result == 0 + expected_calls = [ + call("Context for test-key-123 requires chunk creation/update"), + call("No content to chunk for test_entity test-key-123"), + ] + mock_write.assert_has_calls(expected_calls) diff --git a/backend/tests/apps/ai/common/base/context_command_test.py b/backend/tests/apps/ai/common/base/context_command_test.py new file mode 100644 index 0000000000..ff1d6cf9c5 --- /dev/null +++ b/backend/tests/apps/ai/common/base/context_command_test.py @@ -0,0 +1,305 @@ +"""Tests for the BaseContextCommand class.""" + +from typing import Any +from unittest.mock import Mock, 
patch + +import pytest +from django.core.management.base import BaseCommand + +from apps.ai.common.base.context_command import BaseContextCommand +from apps.ai.models.context import Context + + +class ConcreteContextCommand(BaseContextCommand): + """Concrete implementation of BaseContextCommand for testing.""" + + model_class: type[Any] = Mock # type: ignore[assignment] + entity_name = "test_entity" + entity_name_plural = "test_entities" + key_field_name = "test_key" + + def extract_content(self, entity): + return ("prose content", "metadata content") + + +@pytest.fixture +def command(): + """Return a concrete context command instance for testing.""" + cmd = ConcreteContextCommand() + mock_model = Mock() + mock_model.__name__ = "TestEntity" + cmd.model_class = mock_model + cmd.entity_name = "test_entity" + cmd.entity_name_plural = "test_entities" + return cmd + + +@pytest.fixture +def mock_entity(): + """Return a mock entity instance.""" + entity = Mock() + entity.id = 1 + entity.test_key = "test-key-123" + entity.is_active = True + return entity + + +@pytest.fixture +def mock_context(): + """Return a mock context instance.""" + context = Mock(spec=Context) + context.id = 1 + context.content = "test content" + context.content_type_id = 1 + context.object_id = 1 + return context + + +class TestBaseContextCommand: + """Test suite for the BaseContextCommand class.""" + + def test_command_inheritance(self, command): + """Test that BaseContextCommand inherits from BaseCommand.""" + assert isinstance(command, BaseCommand) + + def test_help_method(self, command): + """Test the help method returns appropriate help text.""" + expected_help = "Update context for OWASP test_entity data" + assert command.help() == expected_help + + def test_abstract_methods_implemented(self, command): + """Test that all abstract methods are properly implemented.""" + assert command.model_class.__name__ == "TestEntity" + assert command.entity_name == "test_entity" + assert 
command.entity_name_plural == "test_entities" + assert command.key_field_name == "test_key" + + mock_entity = Mock() + result = command.extract_content(mock_entity) + assert result == ("prose content", "metadata content") + + def test_process_context_batch_empty_content(self, command, mock_entity): + """Test process_context_batch when extracted content is empty.""" + with ( + patch.object(command, "extract_content", return_value=("", "")), + patch.object(command.stdout, "write") as mock_write, + ): + result = command.process_context_batch([mock_entity]) + + assert result == 0 + mock_write.assert_called_once_with("No content for test_entity test-key-123") + + def test_process_context_batch_whitespace_only_content(self, command, mock_entity): + """Test process_context_batch with whitespace-only content.""" + with ( + patch.object(command, "extract_content", return_value=(" \n\t ", " \t\n ")), + patch.object(command.stdout, "write") as mock_write, + ): + result = command.process_context_batch([mock_entity]) + + assert result == 0 + mock_write.assert_called_once_with("No content for test_entity test-key-123") + + @patch("apps.ai.common.base.context_command.Context") + def test_process_context_batch_success( + self, mock_context_class, command, mock_entity, mock_context + ): + """Test successful context processing.""" + mock_context_class.update_data.return_value = mock_context + + with patch.object(command.stdout, "write") as mock_write: + result = command.process_context_batch([mock_entity]) + + assert result == 1 + mock_context_class.update_data.assert_called_once_with( + content="metadata content\n\nprose content", + entity=mock_entity, + source="owasp_test_entity", + ) + mock_write.assert_called_once_with("Created context for test-key-123") + + @patch("apps.ai.common.base.context_command.Context") + def test_process_context_batch_creation_fails(self, mock_context_class, command, mock_entity): + """Test process_context_batch when context creation fails.""" + 
mock_context_class.update_data.return_value = None + + with patch.object(command.stdout, "write") as mock_write: + result = command.process_context_batch([mock_entity]) + + assert result == 0 + mock_context_class.update_data.assert_called_once() + mock_write.assert_called_once() + call_args = mock_write.call_args[0][0] + assert "Failed to create context for test-key-123" in str(call_args) + + @patch("apps.ai.common.base.context_command.Context") + def test_process_context_batch_multiple_entities( + self, mock_context_class, command, mock_context + ): + """Test processing multiple entities in a batch.""" + entities = [] + for i in range(3): + entity = Mock() + entity.id = i + 1 + entity.test_key = f"test-key-{i + 1}" + entity.is_active = True + entities.append(entity) + + mock_context_class.update_data.return_value = mock_context + + with patch.object(command.stdout, "write") as mock_write: + result = command.process_context_batch(entities) + + assert result == 3 + assert mock_context_class.update_data.call_count == 3 + assert mock_write.call_count == 3 + + calls = mock_context_class.update_data.call_args_list + for i, call in enumerate(calls): + _, kwargs = call + assert kwargs["entity"] == entities[i] + assert kwargs["content"] == "metadata content\n\nprose content" + assert kwargs["source"] == "owasp_test_entity" + + @patch("apps.ai.common.base.context_command.Context") + def test_process_context_batch_mixed_success_failure( + self, mock_context_class, command, mock_context + ): + """Test processing where some entities succeed and others fail.""" + entities = [] + for i in range(3): + entity = Mock() + entity.id = i + 1 + entity.test_key = f"test-key-{i + 1}" + entity.is_active = True + entities.append(entity) + + mock_context_class.update_data.side_effect = [mock_context, None, mock_context] + + with patch.object(command.stdout, "write") as mock_write: + result = command.process_context_batch(entities) + + assert result == 2 + assert 
mock_context_class.update_data.call_count == 3 + assert mock_write.call_count == 3 + + write_calls = mock_write.call_args_list + assert "Created context for test-key-1" in str(write_calls[0]) + assert "Failed to create context for test-key-2" in str(write_calls[1]) + assert "Created context for test-key-3" in str(write_calls[2]) + + def test_process_context_batch_content_combination(self, command, mock_entity, mock_context): + """Test that metadata and prose content are properly combined.""" + with patch("apps.ai.common.base.context_command.Context") as mock_context_class: + mock_context_class.update_data.return_value = mock_context + + with patch.object(command, "extract_content", return_value=("prose", "metadata")): + command.process_context_batch([mock_entity]) + + expected_content = "metadata\n\nprose" + mock_context_class.update_data.assert_called_once() + call_args = mock_context_class.update_data.call_args[1] + assert call_args["content"] == expected_content + + mock_context_class.update_data.reset_mock() + with patch.object(command, "extract_content", return_value=("prose", "")): + command.process_context_batch([mock_entity]) + + call_args = mock_context_class.update_data.call_args[1] + assert call_args["content"] == "prose" + + def test_process_context_batch_metadata_only_content(self, command, mock_entity, mock_context): + """Test process_context_batch with only metadata content.""" + with patch("apps.ai.common.base.context_command.Context") as mock_context_class: + mock_context_class.update_data.return_value = mock_context + + with patch.object(command, "extract_content", return_value=("", "metadata")): + command.process_context_batch([mock_entity]) + + expected_content = "metadata\n\n" + call_args = mock_context_class.update_data.call_args[1] + assert call_args["content"] == expected_content + + @patch.object(BaseContextCommand, "get_queryset") + @patch.object(BaseContextCommand, "handle_batch_processing") + def test_handle_method(self, 
mock_handle_batch, mock_get_queryset, command): + """Test the handle method.""" + mock_queryset = Mock() + mock_get_queryset.return_value = mock_queryset + options = {"batch_size": 10} + + command.handle(**options) + + mock_get_queryset.assert_called_once_with(options) + mock_handle_batch.assert_called_once_with( + queryset=mock_queryset, + batch_size=10, + process_batch_func=command.process_context_batch, + ) + + def test_source_name_usage(self, command, mock_entity, mock_context): + """Test that source_name is properly used in context creation.""" + with ( + patch("apps.ai.common.base.context_command.Context") as mock_context_class, + patch.object(command, "source_name", return_value="custom_source"), + ): + mock_context_class.update_data.return_value = mock_context + + command.process_context_batch([mock_entity]) + + call_args = mock_context_class.update_data.call_args[1] + assert call_args["source"] == "custom_source" + + def test_get_entity_key_usage(self, command, mock_context): + """Test that get_entity_key is properly used for display messages.""" + entity = Mock() + entity.test_key = "custom-entity-key" + + with patch("apps.ai.common.base.context_command.Context") as mock_context_class: + mock_context_class.update_data.return_value = mock_context + + with patch.object(command.stdout, "write") as mock_write: + command.process_context_batch([entity]) + + mock_write.assert_called_once_with("Created context for custom-entity-key") + + def test_process_context_batch_empty_list(self, command): + """Test process_context_batch with empty entity list.""" + result = command.process_context_batch([]) + assert result == 0 + + def test_process_context_batch_skips_empty_entities(self, command): + """Test that entities with empty content are properly skipped.""" + entities = [] + + entity1 = Mock() + entity1.test_key = "entity-1" + entities.append(entity1) + + entity2 = Mock() + entity2.test_key = "entity-2" + entities.append(entity2) + + entity3 = Mock() + 
entity3.test_key = "entity-3" + entities.append(entity3) + + def mock_extract_content(entity): + if entity.test_key == "entity-2": + return ("", "") + return ("prose", "metadata") + + with ( + patch.object(command, "extract_content", side_effect=mock_extract_content), + patch("apps.ai.common.base.context_command.Context") as mock_context_class, + patch.object(command.stdout, "write") as mock_write, + ): + mock_context_class.update_data.return_value = Mock() + + result = command.process_context_batch(entities) + + assert result == 2 + assert mock_context_class.update_data.call_count == 2 + + write_calls = [str(call) for call in mock_write.call_args_list] + assert any("No content for test_entity entity-2" in call for call in write_calls) diff --git a/backend/tests/apps/ai/common/extractors/__init__.py b/backend/tests/apps/ai/common/extractors/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/backend/tests/apps/ai/common/extractors/chapter_test.py b/backend/tests/apps/ai/common/extractors/chapter_test.py new file mode 100644 index 0000000000..f97a96ae24 --- /dev/null +++ b/backend/tests/apps/ai/common/extractors/chapter_test.py @@ -0,0 +1,299 @@ +"""Tests for chapter content extractor.""" + +from unittest.mock import MagicMock + +from apps.ai.common.extractors.chapter import extract_chapter_content + + +class TestChapterExtractor: + """Test cases for chapter content extraction.""" + + def test_extract_chapter_content_full_data(self): + """Test extraction with complete chapter data.""" + chapter = MagicMock() + chapter.description = "Test chapter description" + chapter.summary = "Test chapter summary" + chapter.name = "Test Chapter" + chapter.country = "USA" + chapter.region = "North America" + chapter.postal_code = "12345" + chapter.suggested_location = "New York, NY" + chapter.currency = "USD" + chapter.meetup_group = "owasp-nyc" + chapter.tags = ["security", "web"] + chapter.topics = ["OWASP Top 10", "Security Testing"] + chapter.leaders_raw 
= ["John Doe", "Jane Smith"] + chapter.related_urls = ["https://example.com", "https://github.com/test"] + chapter.invalid_urls = [] + chapter.is_active = True + + repo = MagicMock() + repo.description = "Repository for chapter resources" + repo.topics = ["owasp", "security"] + chapter.owasp_repository = repo + + prose, metadata = extract_chapter_content(chapter) + + assert "Description: Test chapter description" in prose + assert "Summary: Test chapter summary" in prose + assert "Repository Description: Repository for chapter resources" in prose + + assert "Chapter Name: Test Chapter" in metadata + assert ( + "Location Information: Country: USA, Region: North America, " + "Postal Code: 12345, Location: New York, NY" in metadata + ) + assert "Currency: USD" in metadata + assert "Meetup Group: owasp-nyc" in metadata + assert "Tags: security, web" in metadata + assert "Topics: OWASP Top 10, Security Testing" in metadata + assert "Chapter Leaders: John Doe, Jane Smith" in metadata + assert "Related URLs: https://example.com, https://github.com/test" in metadata + assert "Active Chapter: Yes" in metadata + assert "Repository Topics: owasp, security" in metadata + + def test_extract_chapter_content_minimal_data(self): + """Test extraction with minimal chapter data.""" + chapter = MagicMock() + chapter.description = None + chapter.summary = None + chapter.name = "Minimal Chapter" + chapter.country = None + chapter.region = None + chapter.postal_code = None + chapter.suggested_location = None + chapter.currency = None + chapter.meetup_group = None + chapter.tags = [] + chapter.topics = [] + chapter.leaders_raw = [] + chapter.related_urls = [] + chapter.invalid_urls = [] + chapter.is_active = False + chapter.owasp_repository = None + + prose, metadata = extract_chapter_content(chapter) + + assert prose == "" + assert "Chapter Name: Minimal Chapter" in metadata + assert "Description:" not in prose + assert "Currency:" not in metadata + assert "Active Chapter: Yes" not in 
metadata + + def test_extract_chapter_content_empty_fields(self): + """Test extraction with empty string fields.""" + chapter = MagicMock() + chapter.description = "" + chapter.summary = "" + chapter.name = "" + chapter.country = "" + chapter.region = "" + chapter.postal_code = "" + chapter.suggested_location = "" + chapter.currency = "" + chapter.meetup_group = "" + chapter.tags = [] + chapter.topics = [] + chapter.leaders_raw = [] + chapter.related_urls = [] + chapter.invalid_urls = [] + chapter.is_active = False + chapter.owasp_repository = None + + prose, metadata = extract_chapter_content(chapter) + + assert prose == "" + assert metadata == "" + + def test_extract_chapter_content_partial_location(self): + """Test extraction with partial location information.""" + chapter = MagicMock() + chapter.description = None + chapter.summary = None + chapter.name = "Test Chapter" + chapter.country = "Canada" + chapter.region = None + chapter.postal_code = "K1A 0A6" + chapter.suggested_location = None + chapter.currency = None + chapter.meetup_group = None + chapter.tags = [] + chapter.topics = [] + chapter.leaders_raw = [] + chapter.related_urls = [] + chapter.invalid_urls = [] + chapter.is_active = True + chapter.owasp_repository = None + + prose, metadata = extract_chapter_content(chapter) + + assert prose == "" + assert "Chapter Name: Test Chapter" in metadata + assert "Location Information: Country: Canada, Postal Code: K1A 0A6" in metadata + assert "Active Chapter: Yes" in metadata + + def test_extract_chapter_content_with_invalid_urls(self): + """Test extraction with invalid URLs filtered out.""" + chapter = MagicMock() + chapter.description = None + chapter.summary = None + chapter.name = "Test Chapter" + chapter.country = None + chapter.region = None + chapter.postal_code = None + chapter.suggested_location = None + chapter.currency = None + chapter.meetup_group = None + chapter.tags = [] + chapter.topics = [] + chapter.leaders_raw = [] + chapter.related_urls = [ 
+ "https://valid.com", + "https://invalid.com", + "https://another-valid.com", + ] + chapter.invalid_urls = ["https://invalid.com"] + chapter.is_active = False + chapter.owasp_repository = None + + _, metadata = extract_chapter_content(chapter) + + assert "Related URLs: https://valid.com, https://another-valid.com" in metadata + assert "https://invalid.com" not in metadata + + def test_extract_chapter_content_repository_no_description(self): + """Test extraction when repository has no description.""" + chapter = MagicMock() + chapter.description = "Chapter description" + chapter.summary = None + chapter.name = "Test Chapter" + chapter.country = None + chapter.region = None + chapter.postal_code = None + chapter.suggested_location = None + chapter.currency = None + chapter.meetup_group = None + chapter.tags = [] + chapter.topics = [] + chapter.leaders_raw = [] + chapter.related_urls = [] + chapter.invalid_urls = [] + chapter.is_active = False + + repo = MagicMock() + repo.description = None + repo.topics = ["security"] + chapter.owasp_repository = repo + + prose, metadata = extract_chapter_content(chapter) + + assert "Description: Chapter description" in prose + assert "Repository Description:" not in prose + assert "Repository Topics: security" in metadata + + def test_extract_chapter_content_repository_empty_topics(self): + """Test extraction when repository has empty topics.""" + chapter = MagicMock() + chapter.description = "Chapter description" + chapter.summary = None + chapter.name = "Test Chapter" + chapter.country = None + chapter.region = None + chapter.postal_code = None + chapter.suggested_location = None + chapter.currency = None + chapter.meetup_group = None + chapter.tags = [] + chapter.topics = [] + chapter.leaders_raw = [] + chapter.related_urls = [] + chapter.invalid_urls = [] + chapter.is_active = False + + repo = MagicMock() + repo.description = "Repository description" + repo.topics = [] + chapter.owasp_repository = repo + + prose, metadata = 
extract_chapter_content(chapter) + + assert "Description: Chapter description" in prose + assert "Repository Description: Repository description" in prose + assert "Repository Topics:" not in metadata + + def test_extract_chapter_content_none_invalid_urls(self): + """Test extraction when invalid_urls is None.""" + chapter = MagicMock() + chapter.description = None + chapter.summary = None + chapter.name = "Test Chapter" + chapter.country = None + chapter.region = None + chapter.postal_code = None + chapter.suggested_location = None + chapter.currency = None + chapter.meetup_group = None + chapter.tags = [] + chapter.topics = [] + chapter.leaders_raw = [] + chapter.related_urls = ["https://valid.com"] + chapter.invalid_urls = None + chapter.is_active = False + chapter.owasp_repository = None + + _, metadata = extract_chapter_content(chapter) + + assert "Related URLs: https://valid.com" in metadata + + def test_extract_chapter_content_empty_related_urls_after_filter(self): + """Test extraction when all related URLs are invalid.""" + chapter = MagicMock() + chapter.description = None + chapter.summary = None + chapter.name = "Test Chapter" + chapter.country = None + chapter.region = None + chapter.postal_code = None + chapter.suggested_location = None + chapter.currency = None + chapter.meetup_group = None + chapter.tags = [] + chapter.topics = [] + chapter.leaders_raw = [] + chapter.related_urls = ["https://invalid1.com", "https://invalid2.com"] + chapter.invalid_urls = ["https://invalid1.com", "https://invalid2.com"] + chapter.is_active = False + chapter.owasp_repository = None + + _, metadata = extract_chapter_content(chapter) + + assert "Related URLs:" not in metadata + + def test_extract_chapter_content_with_none_and_empty_urls(self): + """Test extraction with mix of None and empty URLs.""" + chapter = MagicMock() + chapter.description = None + chapter.summary = None + chapter.name = "Test Chapter" + chapter.country = None + chapter.region = None + 
chapter.postal_code = None + chapter.suggested_location = None + chapter.currency = None + chapter.meetup_group = None + chapter.tags = [] + chapter.topics = [] + chapter.leaders_raw = [] + chapter.related_urls = [ + "https://valid.com", + None, + "", + "https://another-valid.com", + ] + chapter.invalid_urls = [] + chapter.is_active = False + chapter.owasp_repository = None + + _, metadata = extract_chapter_content(chapter) + + assert "Related URLs: https://valid.com, https://another-valid.com" in metadata diff --git a/backend/tests/apps/ai/common/extractors/committee_test.py b/backend/tests/apps/ai/common/extractors/committee_test.py new file mode 100644 index 0000000000..320a451dd0 --- /dev/null +++ b/backend/tests/apps/ai/common/extractors/committee_test.py @@ -0,0 +1,202 @@ +"""Tests for committee content extractor.""" + +from unittest.mock import MagicMock + +from apps.ai.common.extractors.committee import extract_committee_content + + +class TestCommitteeExtractor: + """Test cases for committee content extraction.""" + + def test_extract_committee_content_full_data(self): + """Test extraction with complete committee data.""" + committee = MagicMock() + committee.description = "Test committee description" + committee.summary = "Test committee summary" + committee.name = "Test Committee" + committee.tags = ["governance", "policy"] + committee.topics = ["Security Standards", "Best Practices"] + committee.leaders_raw = ["Alice Johnson", "Bob Wilson"] + committee.related_urls = ["https://committee.example.com"] + committee.invalid_urls = [] + committee.is_active = True + + repo = MagicMock() + repo.description = "Repository for committee resources" + repo.topics = ["governance", "standards"] + committee.owasp_repository = repo + + prose, metadata = extract_committee_content(committee) + + assert "Description: Test committee description" in prose + assert "Summary: Test committee summary" in prose + assert "Repository Description: Repository for committee 
resources" in prose + + assert "Committee Name: Test Committee" in metadata + assert "Tags: governance, policy" in metadata + assert "Topics: Security Standards, Best Practices" in metadata + assert "Committee Leaders: Alice Johnson, Bob Wilson" in metadata + assert "Related URLs: https://committee.example.com" in metadata + assert "Active Committee: Yes" in metadata + assert "Repository Topics: governance, standards" in metadata + + def test_extract_committee_content_minimal_data(self): + """Test extraction with minimal committee data.""" + committee = MagicMock() + committee.description = None + committee.summary = None + committee.name = "Minimal Committee" + committee.tags = [] + committee.topics = [] + committee.leaders_raw = [] + committee.related_urls = [] + committee.invalid_urls = [] + committee.is_active = False + committee.owasp_repository = None + + prose, metadata = extract_committee_content(committee) + + assert prose == "" + assert "Committee Name: Minimal Committee" in metadata + assert "Active Committee: No" in metadata + + def test_extract_committee_content_inactive_committee(self): + """Test extraction with inactive committee.""" + committee = MagicMock() + committee.description = "Inactive committee" + committee.summary = None + committee.name = "Inactive Committee" + committee.tags = [] + committee.topics = [] + committee.leaders_raw = [] + committee.related_urls = [] + committee.invalid_urls = [] + committee.is_active = False + committee.owasp_repository = None + + prose, metadata = extract_committee_content(committee) + + assert "Description: Inactive committee" in prose + assert "Active Committee: No" in metadata + + def test_extract_committee_content_with_invalid_urls(self): + """Test extraction with invalid URLs filtered out.""" + committee = MagicMock() + committee.description = None + committee.summary = None + committee.name = "Test Committee" + committee.tags = [] + committee.topics = [] + committee.leaders_raw = [] + 
committee.related_urls = ["https://valid.com", "https://invalid.com"] + committee.invalid_urls = ["https://invalid.com"] + committee.is_active = True + committee.owasp_repository = None + + _, metadata = extract_committee_content(committee) + + assert "Related URLs: https://valid.com" in metadata + assert "https://invalid.com" not in metadata + + def test_extract_committee_content_no_invalid_urls_attr(self): + """Test extraction when invalid_urls attribute doesn't exist.""" + committee = MagicMock() + committee.description = None + committee.summary = None + committee.name = "Test Committee" + committee.tags = [] + committee.topics = [] + committee.leaders_raw = [] + committee.related_urls = ["https://valid.com"] + committee.is_active = True + committee.owasp_repository = None + del committee.invalid_urls + + _, metadata = extract_committee_content(committee) + + assert "Related URLs: https://valid.com" in metadata + + def test_extract_committee_content_empty_strings(self): + """Test extraction with empty string fields.""" + committee = MagicMock() + committee.description = "" + committee.summary = "" + committee.name = "" + committee.tags = [] + committee.topics = [] + committee.leaders_raw = [] + committee.related_urls = [] + committee.invalid_urls = [] + committee.is_active = True + committee.owasp_repository = None + + prose, metadata = extract_committee_content(committee) + + assert prose == "" + assert "Active Committee: Yes" in metadata + assert "Committee Name:" not in metadata + + def test_extract_committee_content_repository_no_description(self): + """Test extraction when repository has no description.""" + committee = MagicMock() + committee.description = "Committee description" + committee.summary = None + committee.name = "Test Committee" + committee.tags = [] + committee.topics = [] + committee.leaders_raw = [] + committee.related_urls = [] + committee.invalid_urls = [] + committee.is_active = True + + repo = MagicMock() + repo.description = None + 
repo.topics = ["topic1"] + committee.owasp_repository = repo + + prose, metadata = extract_committee_content(committee) + + assert "Description: Committee description" in prose + assert "Repository Description:" not in prose + assert "Repository Topics: topic1" in metadata + + def test_extract_committee_content_repository_empty_topics(self): + """Test extraction when repository has empty topics.""" + committee = MagicMock() + committee.description = None + committee.summary = None + committee.name = "Test Committee" + committee.tags = [] + committee.topics = [] + committee.leaders_raw = [] + committee.related_urls = [] + committee.invalid_urls = [] + committee.is_active = True + + repo = MagicMock() + repo.description = "Repo description" + repo.topics = [] + committee.owasp_repository = repo + + prose, metadata = extract_committee_content(committee) + + assert "Repository Description: Repo description" in prose + assert "Repository Topics:" not in metadata + + def test_extract_committee_content_all_empty_after_filter(self): + """Test extraction when all URLs are filtered out.""" + committee = MagicMock() + committee.description = None + committee.summary = None + committee.name = "Test Committee" + committee.tags = [] + committee.topics = [] + committee.leaders_raw = [] + committee.related_urls = ["https://invalid1.com", "https://invalid2.com"] + committee.invalid_urls = ["https://invalid1.com", "https://invalid2.com"] + committee.is_active = True + committee.owasp_repository = None + + _, metadata = extract_committee_content(committee) + + assert "Related URLs:" not in metadata diff --git a/backend/tests/apps/ai/common/extractors/event_test.py b/backend/tests/apps/ai/common/extractors/event_test.py new file mode 100644 index 0000000000..8451674f4e --- /dev/null +++ b/backend/tests/apps/ai/common/extractors/event_test.py @@ -0,0 +1,208 @@ +"""Tests for event content extractor.""" + +from datetime import date +from unittest.mock import MagicMock + +from 
apps.ai.common.extractors.event import extract_event_content + + +class TestEventExtractor: + """Test cases for event content extraction.""" + + def test_extract_event_content_full_data(self): + """Test extraction with complete event data.""" + event = MagicMock() + event.description = "Test event description" + event.summary = "Test event summary" + event.name = "Test Event" + event.category = "conference" + event.get_category_display.return_value = "Conference" + event.start_date = date(2024, 6, 15) + event.end_date = date(2024, 6, 17) + event.suggested_location = "San Francisco, CA" + event.latitude = 37.7749 + event.longitude = -122.4194 + event.url = "https://event.example.com" + + prose, metadata = extract_event_content(event) + + assert "Description: Test event description" in prose + assert "Summary: Test event summary" in prose + + assert "Event Name: Test Event" in metadata + assert "Category: Conference" in metadata + assert "Start Date: 2024-06-15" in metadata + assert "End Date: 2024-06-17" in metadata + assert "Location: San Francisco, CA" in metadata + assert "Coordinates: 37.7749, -122.4194" in metadata + assert "Event URL: https://event.example.com" in metadata + + def test_extract_event_content_minimal_data(self): + """Test extraction with minimal event data.""" + event = MagicMock() + event.description = None + event.summary = None + event.name = "Minimal Event" + event.category = None + event.start_date = None + event.end_date = None + event.suggested_location = None + event.latitude = None + event.longitude = None + event.url = None + + prose, metadata = extract_event_content(event) + + assert prose == "" + assert "Event Name: Minimal Event" in metadata + assert "Category:" not in metadata + assert "Start Date:" not in metadata + assert "End Date:" not in metadata + assert "Location:" not in metadata + assert "Coordinates:" not in metadata + assert "Event URL:" not in metadata + + def test_extract_event_content_empty_strings(self): + """Test 
extraction with empty string fields.""" + event = MagicMock() + event.description = "" + event.summary = "" + event.name = "" + event.category = "" + event.get_category_display.return_value = "" + event.start_date = None + event.end_date = None + event.suggested_location = "" + event.latitude = None + event.longitude = None + event.url = "" + + prose, metadata = extract_event_content(event) + + assert prose == "" + assert metadata == "" + + def test_extract_event_content_only_latitude(self): + """Test extraction with only latitude (no coordinates).""" + event = MagicMock() + event.description = None + event.summary = None + event.name = "Test Event" + event.category = None + event.start_date = None + event.end_date = None + event.suggested_location = None + event.latitude = 37.7749 + event.longitude = None + event.url = None + + prose, metadata = extract_event_content(event) + + assert prose == "" + assert "Event Name: Test Event" in metadata + assert "Coordinates:" not in metadata + + def test_extract_event_content_only_longitude(self): + """Test extraction with only longitude (no coordinates).""" + event = MagicMock() + event.description = None + event.summary = None + event.name = "Test Event" + event.category = None + event.start_date = None + event.end_date = None + event.suggested_location = None + event.latitude = None + event.longitude = -122.4194 + event.url = None + + prose, metadata = extract_event_content(event) + + assert prose == "" + assert "Event Name: Test Event" in metadata + assert "Coordinates:" not in metadata + + def test_extract_event_content_zero_coordinates(self): + """Test extraction with zero coordinates (should be included).""" + event = MagicMock() + event.description = None + event.summary = None + event.name = "Test Event" + event.category = None + event.start_date = None + event.end_date = None + event.suggested_location = None + event.latitude = 0.0 + event.longitude = 0.0 + event.url = None + + _, metadata = 
extract_event_content(event) + + assert "Event Name: Test Event" in metadata + assert "Coordinates: 0.0, 0.0" in metadata + + def test_extract_event_content_partial_dates(self): + """Test extraction with only start date.""" + event = MagicMock() + event.description = "Event with start date only" + event.summary = None + event.name = "Partial Event" + event.category = "workshop" + event.get_category_display.return_value = "Workshop" + event.start_date = date(2024, 8, 1) + event.end_date = None + event.suggested_location = "Online" + event.latitude = None + event.longitude = None + event.url = "https://online-event.com" + + prose, metadata = extract_event_content(event) + + assert "Description: Event with start date only" in prose + assert "Event Name: Partial Event" in metadata + assert "Category: Workshop" in metadata + assert "Start Date: 2024-08-01" in metadata + assert "End Date:" not in metadata + assert "Location: Online" in metadata + assert "Event URL: https://online-event.com" in metadata + + def test_extract_event_content_only_end_date(self): + """Test extraction with only end date.""" + event = MagicMock() + event.description = None + event.summary = None + event.name = "End Date Event" + event.category = None + event.start_date = None + event.end_date = date(2024, 12, 25) + event.suggested_location = None + event.latitude = None + event.longitude = None + event.url = None + + prose, metadata = extract_event_content(event) + + assert prose == "" + assert "Event Name: End Date Event" in metadata + assert "Start Date:" not in metadata + assert "End Date: 2024-12-25" in metadata + + def test_extract_event_content_category_display_method(self): + """Test that get_category_display method is called properly.""" + event = MagicMock() + event.description = None + event.summary = None + event.name = "Category Test Event" + event.category = "meetup" + event.get_category_display.return_value = "Meetup" + event.start_date = None + event.end_date = None + 
event.suggested_location = None + event.latitude = None + event.longitude = None + event.url = None + + _, metadata = extract_event_content(event) + + event.get_category_display.assert_called_once() + assert "Category: Meetup" in metadata diff --git a/backend/tests/apps/ai/common/extractors/project_test.py b/backend/tests/apps/ai/common/extractors/project_test.py new file mode 100644 index 0000000000..db5cbfde47 --- /dev/null +++ b/backend/tests/apps/ai/common/extractors/project_test.py @@ -0,0 +1,513 @@ +"""Tests for project content extractor.""" + +from datetime import UTC, datetime +from unittest.mock import MagicMock + +from apps.ai.common.extractors.project import extract_project_content + + +class TestProjectExtractor: + """Test cases for project content extraction.""" + + def test_extract_project_content_full_data(self): + """Test extraction with complete project data.""" + project = MagicMock() + project.description = "Test project description" + project.summary = "Test project summary" + project.name = "Test Project" + project.level = "flagship" + project.type = "tool" + project.languages = ["Python", "JavaScript"] + project.topics = ["security", "web-application-security"] + project.licenses = ["MIT", "Apache-2.0"] + project.tags = ["security", "testing"] + project.custom_tags = ["owasp-top-10"] + project.stars_count = 1500 + project.forks_count = 300 + project.contributors_count = 45 + project.releases_count = 12 + project.open_issues_count = 8 + project.leaders_raw = ["John Doe", "Jane Smith"] + project.related_urls = ["https://project.example.com"] + project.invalid_urls = [] + project.created_at = datetime(2020, 1, 15, tzinfo=UTC) + project.updated_at = datetime(2024, 6, 10, tzinfo=UTC) + project.released_at = datetime(2024, 5, 20, tzinfo=UTC) + project.health_score = 85.75 + project.is_active = True + + repo = MagicMock() + repo.description = "Repository for project resources" + repo.topics = ["security", "python"] + project.owasp_repository = repo + 
+ prose, metadata = extract_project_content(project) + + assert "Description: Test project description" in prose + assert "Summary: Test project summary" in prose + assert "Repository Description: Repository for project resources" in prose + + assert "Project Name: Test Project" in metadata + assert "Project Level: flagship" in metadata + assert "Project Type: tool" in metadata + assert "Programming Languages: Python, JavaScript" in metadata + assert "Topics: security, web-application-security" in metadata + assert "Licenses: MIT, Apache-2.0" in metadata + assert "Tags: security, testing" in metadata + assert "Custom Tags: owasp-top-10" in metadata + assert ( + "Project Statistics: Stars: 1500, Forks: 300, Contributors: 45, " + "Releases: 12, Open Issues: 8" in metadata + ) + assert "Project Leaders: John Doe, Jane Smith" in metadata + assert "Related URLs: https://project.example.com" in metadata + assert "Created: 2020-01-15" in metadata + assert "Last Updated: 2024-06-10" in metadata + assert "Last Release: 2024-05-20" in metadata + assert "Health Score: 85.75" in metadata + assert "Active Project: Yes" in metadata + assert "Repository Topics: security, python" in metadata + + def test_extract_project_content_minimal_data(self): + """Test extraction with minimal project data.""" + project = MagicMock() + project.description = None + project.summary = None + project.name = "Minimal Project" + project.level = None + project.type = None + project.languages = [] + project.topics = [] + project.licenses = [] + project.tags = [] + project.custom_tags = [] + project.stars_count = None + project.forks_count = None + project.contributors_count = None + project.releases_count = None + project.open_issues_count = None + project.leaders_raw = [] + project.related_urls = [] + project.invalid_urls = [] + project.created_at = None + project.updated_at = None + project.released_at = None + project.health_score = None + project.is_active = False + project.owasp_repository = None 
+ + prose, metadata = extract_project_content(project) + + assert prose == "" + assert "Project Name: Minimal Project" in metadata + assert "Active Project: No" in metadata + + def test_extract_project_content_partial_statistics(self): + """Test extraction with partial statistics.""" + project = MagicMock() + project.description = None + project.summary = None + project.name = "Partial Stats Project" + project.level = None + project.type = None + project.languages = [] + project.topics = [] + project.licenses = [] + project.tags = [] + project.custom_tags = [] + project.stars_count = 100 + project.forks_count = None + project.contributors_count = 5 + project.releases_count = None + project.open_issues_count = 3 + project.leaders_raw = [] + project.related_urls = [] + project.invalid_urls = [] + project.created_at = None + project.updated_at = None + project.released_at = None + project.health_score = None + project.is_active = True + project.owasp_repository = None + + _, metadata = extract_project_content(project) + + assert "Project Statistics: Stars: 100, Contributors: 5, Open Issues: 3" in metadata + + def test_extract_project_content_zero_statistics(self): + """Test extraction with zero values in statistics.""" + project = MagicMock() + project.description = None + project.summary = None + project.name = "Zero Stats Project" + project.level = None + project.type = None + project.languages = [] + project.topics = [] + project.licenses = [] + project.tags = [] + project.custom_tags = [] + project.stars_count = 0 + project.forks_count = 0 + project.contributors_count = 0 + project.releases_count = 0 + project.open_issues_count = 0 + project.leaders_raw = [] + project.related_urls = [] + project.invalid_urls = [] + project.created_at = None + project.updated_at = None + project.released_at = None + project.health_score = None + project.is_active = True + project.owasp_repository = None + + _, metadata = extract_project_content(project) + + assert "Project 
Statistics:" not in metadata + + def test_extract_project_content_with_invalid_urls(self): + """Test extraction with invalid URLs filtered out.""" + project = MagicMock() + project.description = None + project.summary = None + project.name = "URL Test Project" + project.level = None + project.type = None + project.languages = [] + project.topics = [] + project.licenses = [] + project.tags = [] + project.custom_tags = [] + project.stars_count = None + project.forks_count = None + project.contributors_count = None + project.releases_count = None + project.open_issues_count = None + project.leaders_raw = [] + project.related_urls = ["https://valid.com", "https://invalid.com"] + project.invalid_urls = ["https://invalid.com"] + project.created_at = None + project.updated_at = None + project.released_at = None + project.health_score = None + project.is_active = True + project.owasp_repository = None + + _, metadata = extract_project_content(project) + + assert "Related URLs: https://valid.com" in metadata + assert "https://invalid.com" not in metadata + + def test_extract_project_content_dates_only(self): + """Test extraction with only date fields.""" + project = MagicMock() + project.description = None + project.summary = None + project.name = "Date Project" + project.level = None + project.type = None + project.languages = [] + project.topics = [] + project.licenses = [] + project.tags = [] + project.custom_tags = [] + project.stars_count = None + project.forks_count = None + project.contributors_count = None + project.releases_count = None + project.open_issues_count = None + project.leaders_raw = [] + project.related_urls = [] + project.invalid_urls = [] + project.created_at = datetime(2021, 3, 1, tzinfo=UTC) + project.updated_at = None + project.released_at = datetime(2023, 8, 15, tzinfo=UTC) + project.health_score = None + project.is_active = True + project.owasp_repository = None + + _, metadata = extract_project_content(project) + + assert "Created: 2021-03-01" 
in metadata + assert "Last Updated:" not in metadata + assert "Last Release: 2023-08-15" in metadata + + def test_extract_project_content_health_score_zero(self): + """Test extraction with zero health score.""" + project = MagicMock() + project.description = None + project.summary = None + project.name = "Zero Health Project" + project.level = None + project.type = None + project.languages = [] + project.topics = [] + project.licenses = [] + project.tags = [] + project.custom_tags = [] + project.stars_count = None + project.forks_count = None + project.contributors_count = None + project.releases_count = None + project.open_issues_count = None + project.leaders_raw = [] + project.related_urls = [] + project.invalid_urls = [] + project.created_at = None + project.updated_at = None + project.released_at = None + project.health_score = 0.0 + project.is_active = True + project.owasp_repository = None + + _, metadata = extract_project_content(project) + + assert "Health Score: 0.00" in metadata + + def test_extract_project_content_repository_no_description(self): + """Test extraction when repository has no description.""" + project = MagicMock() + project.description = "Project description" + project.summary = None + project.name = "Test Project" + project.level = None + project.type = None + project.languages = [] + project.topics = [] + project.licenses = [] + project.tags = [] + project.custom_tags = [] + project.stars_count = None + project.forks_count = None + project.contributors_count = None + project.releases_count = None + project.open_issues_count = None + project.leaders_raw = [] + project.related_urls = [] + project.invalid_urls = [] + project.created_at = None + project.updated_at = None + project.released_at = None + project.health_score = None + project.is_active = True + + repo = MagicMock() + repo.description = None + repo.topics = ["security"] + project.owasp_repository = repo + + prose, metadata = extract_project_content(project) + + assert 
"Description: Project description" in prose + assert "Repository Description:" not in prose + assert "Repository Topics: security" in metadata + + def test_extract_project_content_no_invalid_urls_attr(self): + """Test extraction when invalid_urls attribute doesn't exist.""" + project = MagicMock() + project.description = None + project.summary = None + project.name = "Test Project" + project.level = None + project.type = None + project.languages = [] + project.topics = [] + project.licenses = [] + project.tags = [] + project.custom_tags = [] + project.stars_count = None + project.forks_count = None + project.contributors_count = None + project.releases_count = None + project.open_issues_count = None + project.leaders_raw = [] + project.related_urls = ["https://valid.com"] + project.created_at = None + project.updated_at = None + project.released_at = None + project.health_score = None + project.is_active = True + project.owasp_repository = None + del project.invalid_urls + + _, metadata = extract_project_content(project) + + assert "Related URLs: https://valid.com" in metadata + + def test_extract_project_content_empty_strings(self): + """Test extraction with empty string fields.""" + project = MagicMock() + project.description = "" + project.summary = "" + project.name = "" + project.level = "" + project.type = "" + project.languages = [] + project.topics = [] + project.licenses = [] + project.tags = [] + project.custom_tags = [] + project.stars_count = None + project.forks_count = None + project.contributors_count = None + project.releases_count = None + project.open_issues_count = None + project.leaders_raw = [] + project.related_urls = [] + project.invalid_urls = [] + project.created_at = None + project.updated_at = None + project.released_at = None + project.health_score = None + project.is_active = True + project.owasp_repository = None + + prose, metadata = extract_project_content(project) + + assert prose == "" + assert "Active Project: Yes" in metadata + 
assert "Project Name:" not in metadata + assert "Project Level:" not in metadata + assert "Project Type:" not in metadata + + def test_extract_project_content_repository_with_topics_only(self): + """Test extraction when repository has topics but no description.""" + project = MagicMock() + project.description = None + project.summary = None + project.name = "Test Project" + project.level = None + project.type = None + project.languages = [] + project.topics = [] + project.licenses = [] + project.tags = [] + project.custom_tags = [] + project.stars_count = None + project.forks_count = None + project.contributors_count = None + project.releases_count = None + project.open_issues_count = None + project.leaders_raw = [] + project.related_urls = [] + project.invalid_urls = [] + project.created_at = None + project.updated_at = None + project.released_at = None + project.health_score = None + project.is_active = True + + repo = MagicMock() + repo.description = None + repo.topics = ["security", "python"] + project.owasp_repository = repo + + prose, metadata = extract_project_content(project) + + assert "Repository Description:" not in prose + assert "Repository Topics: security, python" in metadata + + def test_extract_project_content_with_empty_related_urls(self): + """Test extraction with related_urls containing empty strings.""" + project = MagicMock() + project.description = None + project.summary = None + project.name = "Test Project" + project.level = None + project.type = None + project.languages = [] + project.topics = [] + project.licenses = [] + project.tags = [] + project.custom_tags = [] + project.stars_count = None + project.forks_count = None + project.contributors_count = None + project.releases_count = None + project.open_issues_count = None + project.leaders_raw = [] + project.related_urls = ["https://valid.com", "", "https://another.com"] + project.invalid_urls = [] + project.created_at = None + project.updated_at = None + project.released_at = None + 
project.health_score = None + project.is_active = True + project.owasp_repository = None + + _, metadata = extract_project_content(project) + + assert "Related URLs: https://valid.com, https://another.com" in metadata + + def test_extract_project_content_repository_no_description_no_topics(self): + """Test extraction when repository exists but has neither description nor topics.""" + project = MagicMock() + project.description = "Project description" + project.summary = None + project.name = "Test Project" + project.level = None + project.type = None + project.languages = [] + project.topics = [] + project.licenses = [] + project.tags = [] + project.custom_tags = [] + project.stars_count = None + project.forks_count = None + project.contributors_count = None + project.releases_count = None + project.open_issues_count = None + project.leaders_raw = [] + project.related_urls = [] + project.invalid_urls = [] + project.created_at = None + project.updated_at = None + project.released_at = None + project.health_score = None + project.is_active = True + + repo = MagicMock() + repo.description = None + repo.topics = None + project.owasp_repository = repo + + prose, metadata = extract_project_content(project) + + assert "Description: Project description" in prose + assert "Repository Description:" not in prose + assert "Repository Topics:" not in metadata + + def test_extract_project_content_created_and_released_only(self): + """Test extraction with created_at and released_at but no updated_at.""" + project = MagicMock() + project.description = None + project.summary = None + project.name = "Date Test Project" + project.level = None + project.type = None + project.languages = [] + project.topics = [] + project.licenses = [] + project.tags = [] + project.custom_tags = [] + project.stars_count = None + project.forks_count = None + project.contributors_count = None + project.releases_count = None + project.open_issues_count = None + project.leaders_raw = [] + 
project.related_urls = [] + project.invalid_urls = [] + project.created_at = datetime(2021, 3, 1, tzinfo=UTC) + project.updated_at = None + project.released_at = datetime(2023, 8, 15, tzinfo=UTC) + project.health_score = None + project.is_active = True + project.owasp_repository = None + + _, metadata = extract_project_content(project) + + assert "Created: 2021-03-01" in metadata + assert "Last Updated:" not in metadata + assert "Last Release: 2023-08-15" in metadata diff --git a/backend/tests/apps/ai/common/utils_test.py b/backend/tests/apps/ai/common/utils_test.py index 6cc5057e79..2d18684f44 100644 --- a/backend/tests/apps/ai/common/utils_test.py +++ b/backend/tests/apps/ai/common/utils_test.py @@ -1,6 +1,12 @@ -from unittest.mock import MagicMock, patch +from datetime import UTC, datetime, timedelta +from unittest.mock import MagicMock, call, patch -from apps.ai.common.utils import create_chunks_and_embeddings +import openai + +from apps.ai.common.utils import ( + create_chunks_and_embeddings, + regenerate_chunks_for_context, +) class MockEmbeddingData: @@ -11,8 +17,16 @@ def __init__(self, embedding): class TestUtils: @patch("apps.ai.common.utils.Chunk.update_data") @patch("apps.ai.common.utils.time.sleep") - def test_create_chunks_and_embeddings_success(self, mock_sleep, mock_update_data): + @patch("apps.ai.common.utils.datetime") + def test_create_chunks_and_embeddings_success( + self, mock_datetime, mock_sleep, mock_update_data + ): """Tests the successful path where the OpenAI API returns embeddings.""" + base_time = datetime.now(UTC) + mock_datetime.now.return_value = base_time + mock_datetime.UTC = UTC + mock_datetime.timedelta = timedelta + mock_openai_client = MagicMock() mock_api_response = MagicMock() mock_api_response.data = [ @@ -21,7 +35,9 @@ def test_create_chunks_and_embeddings_success(self, mock_sleep, mock_update_data ] mock_openai_client.embeddings.create.return_value = mock_api_response - mock_update_data.return_value = "mock_chunk_instance" 
+ mock_chunk1 = MagicMock() + mock_chunk2 = MagicMock() + mock_update_data.side_effect = [mock_chunk1, mock_chunk2] all_chunk_texts = ["first chunk", "second chunk"] mock_content_object = MagicMock() @@ -37,15 +53,24 @@ def test_create_chunks_and_embeddings_success(self, mock_sleep, mock_update_data model="text-embedding-3-small", ) - assert mock_update_data.call_count == 2 - mock_update_data.assert_any_call( - content_object=mock_content_object, - embedding=[0.1, 0.2], - save=False, - text="first chunk", + mock_update_data.assert_has_calls( + [ + call( + text="first chunk", + embedding=[0.1, 0.2], + context=mock_content_object, + save=True, + ), + call( + text="second chunk", + embedding=[0.3, 0.4], + context=mock_content_object, + save=True, + ), + ] ) - assert result == ["mock_chunk_instance", "mock_chunk_instance"] + assert result == [mock_chunk1, mock_chunk2] mock_sleep.assert_not_called() @@ -53,14 +78,270 @@ def test_create_chunks_and_embeddings_success(self, mock_sleep, mock_update_data def test_create_chunks_and_embeddings_api_error(self, mock_logger): """Tests the failure path where the OpenAI API raises an exception.""" mock_openai_client = MagicMock() - mock_openai_client.embeddings.create.side_effect = Exception("API connection failed") + + mock_openai_client.embeddings.create.side_effect = openai.OpenAIError( + "API connection failed" + ) result = create_chunks_and_embeddings( - all_chunk_texts=["some text"], - content_object=MagicMock(), + chunk_texts=["some text"], + context=MagicMock(), openai_client=mock_openai_client, ) - mock_logger.exception.assert_called_once_with("OpenAI API error") + mock_logger.exception.assert_called_once_with("Failed to create chunks and embeddings") assert result == [] + + @patch("apps.ai.common.utils.logger") + def test_create_chunks_and_embeddings_none_context(self, mock_logger): + """Tests the failure path when context is None.""" + mock_openai_client = MagicMock() + + mock_response = MagicMock() + mock_response.data 
= [MagicMock(embedding=[0.1, 0.2, 0.3])] + mock_openai_client.embeddings.create.return_value = mock_response + + result = create_chunks_and_embeddings( + chunk_texts=["some text"], + context=None, + openai_client=mock_openai_client, + ) + + mock_logger.exception.assert_called_once_with("Failed to create chunks and embeddings") + assert result == [] + + @patch("apps.ai.common.utils.Chunk.update_data") + @patch("apps.ai.common.utils.time.sleep") + @patch("apps.ai.common.utils.datetime") + def test_create_chunks_and_embeddings_sleep_called( + self, mock_datetime, mock_sleep, mock_update_data + ): + """Tests that sleep is not called when no rate limiting is needed.""" + base_time = datetime.now(UTC) + mock_datetime.now.return_value = base_time + mock_datetime.UTC = UTC + mock_datetime.timedelta = timedelta + + mock_openai_client = MagicMock() + mock_api_response = MagicMock() + mock_api_response.data = [MockEmbeddingData([0.1, 0.2])] + mock_openai_client.embeddings.create.return_value = mock_api_response + + mock_chunk_instance = MagicMock() + mock_update_data.return_value = mock_chunk_instance + + mock_content_object = MagicMock() + + result = create_chunks_and_embeddings( + ["test chunk"], + mock_content_object, + mock_openai_client, + ) + + mock_sleep.assert_not_called() + assert result == [mock_chunk_instance] + + @patch("apps.ai.common.utils.Chunk.update_data") + @patch("apps.ai.common.utils.time.sleep") + @patch("apps.ai.common.utils.datetime") + def test_create_chunks_and_embeddings_no_sleep_with_current_settings( + self, mock_datetime, mock_sleep, mock_update_data + ): + """Tests that sleep is not called with current offset settings.""" + base_time = datetime.now(UTC) + mock_datetime.now.return_value = base_time + mock_datetime.UTC = UTC + mock_datetime.timedelta = timedelta + + mock_openai_client = MagicMock() + mock_api_response = MagicMock() + mock_api_response.data = [MockEmbeddingData([0.1, 0.2])] + mock_openai_client.embeddings.create.return_value = mock_api_response + + 
mock_chunk_instance = MagicMock() + mock_update_data.return_value = mock_chunk_instance + mock_context_obj = MagicMock() + + result = create_chunks_and_embeddings( + ["test chunk"], + mock_context_obj, + mock_openai_client, + ) + + mock_sleep.assert_not_called() + mock_update_data.assert_called_once_with( + text="test chunk", embedding=[0.1, 0.2], context=mock_context_obj, save=True + ) + assert result == [mock_chunk_instance] + + @patch("apps.ai.common.utils.Chunk.update_data") + @patch("apps.ai.common.utils.time.sleep") + @patch("apps.ai.common.utils.datetime") + def test_create_chunks_and_embeddings_sleep_when_rate_limited( + self, mock_datetime, mock_sleep, mock_update_data + ): + """Tests that sleep is called when rate limiting is needed.""" + base_time = datetime.now(UTC) + + mock_datetime.now.side_effect = [ + base_time, + base_time, + ] + mock_datetime.UTC = UTC + mock_datetime.timedelta = timedelta + + mock_openai_client = MagicMock() + mock_api_response = MagicMock() + mock_api_response.data = [MockEmbeddingData([0.1, 0.2])] + mock_openai_client.embeddings.create.return_value = mock_api_response + + mock_chunk_instance = MagicMock() + mock_update_data.return_value = mock_chunk_instance + + with ( + patch("apps.ai.common.utils.DEFAULT_LAST_REQUEST_OFFSET_SECONDS", 5), + patch("apps.ai.common.utils.MIN_REQUEST_INTERVAL_SECONDS", 10), + ): + mock_context_obj = MagicMock() + + result = create_chunks_and_embeddings( + ["test chunk"], + mock_context_obj, + mock_openai_client, + ) + + mock_sleep.assert_called_once() + assert result == [mock_chunk_instance] + + @patch("apps.ai.common.utils.Chunk.update_data") + def test_create_chunks_and_embeddings_filter_none_chunks(self, mock_update_data): + """Tests that None chunks are filtered out from results.""" + mock_openai_client = MagicMock() + mock_api_response = MagicMock() + mock_api_response.data = [ + MockEmbeddingData([0.1, 0.2]), + MockEmbeddingData([0.3, 0.4]), + ] + 
mock_openai_client.embeddings.create.return_value = mock_api_response + + mock_chunk_instance = MagicMock() + mock_update_data.side_effect = [mock_chunk_instance, None] + + mock_context_obj = MagicMock() + + result = create_chunks_and_embeddings( + ["first chunk", "second chunk"], + mock_context_obj, + mock_openai_client, + ) + + assert len(result) == 1 + assert result[0] == mock_chunk_instance + + +class TestRegenerateChunksForContext: + """Test cases for regenerate_chunks_for_context function.""" + + @patch("apps.ai.common.utils.logger") + @patch("apps.ai.common.utils.OpenAI") + @patch("apps.ai.common.utils.create_chunks_and_embeddings") + @patch("apps.ai.common.utils.Chunk.split_text") + def test_regenerate_chunks_for_context_success( + self, mock_split_text, mock_create_chunks, mock_openai_class, mock_logger + ): + """Test successful regeneration of chunks for context.""" + context = MagicMock() + context.content = "This is test content for chunking" + + mock_existing_chunks = MagicMock() + context.chunks = mock_existing_chunks + + new_chunk_texts = ["chunk1", "chunk2"] + mock_split_text.return_value = new_chunk_texts + + mock_openai_client = MagicMock() + mock_openai_class.return_value = mock_openai_client + + regenerate_chunks_for_context(context) + + mock_existing_chunks.all.assert_called_once() + mock_existing_chunks.all().delete.assert_called_once() + + mock_split_text.assert_called_once_with(context.content) + + mock_openai_class.assert_called_once() + + mock_create_chunks.assert_called_once_with( + chunk_texts=new_chunk_texts, + context=context, + openai_client=mock_openai_client, + save=True, + ) + + mock_logger.info.assert_called_once_with( + "Successfully completed chunk regeneration for new context" + ) + + @patch("apps.ai.common.utils.logger") + @patch("apps.ai.common.utils.Chunk.split_text") + def test_regenerate_chunks_for_context_no_content(self, mock_split_text, mock_logger): + """Test regeneration when there's no content to chunk.""" + context 
= MagicMock() + context.content = "Some content" + + mock_existing_chunks = MagicMock() + context.chunks = mock_existing_chunks + + mock_split_text.return_value = [] + + regenerate_chunks_for_context(context) + + mock_existing_chunks.all.assert_called_once() + mock_existing_chunks.all().delete.assert_called_once() + + mock_split_text.assert_called_once_with(context.content) + + mock_logger.warning.assert_called_once_with( + "No content to chunk for Context. Process stopped." + ) + + mock_logger.info.assert_not_called() + + @patch("apps.ai.common.utils.logger") + @patch("apps.ai.common.utils.OpenAI") + @patch("apps.ai.common.utils.create_chunks_and_embeddings") + @patch("apps.ai.common.utils.Chunk.split_text") + def test_regenerate_chunks_for_context_no_existing_chunks( + self, mock_split_text, mock_create_chunks, mock_openai_class, mock_logger + ): + """Test regeneration when there are no existing chunks.""" + context = MagicMock() + context.content = "This is test content for chunking" + + mock_existing_chunks = MagicMock() + context.chunks = mock_existing_chunks + + new_chunk_texts = ["chunk1", "chunk2"] + mock_split_text.return_value = new_chunk_texts + + mock_openai_client = MagicMock() + mock_openai_class.return_value = mock_openai_client + + regenerate_chunks_for_context(context) + + mock_existing_chunks.all.assert_called_once() + mock_existing_chunks.all().delete.assert_called_once() + + mock_split_text.assert_called_once_with(context.content) + + mock_create_chunks.assert_called_once_with( + chunk_texts=new_chunk_texts, + context=context, + openai_client=mock_openai_client, + save=True, + ) + + mock_logger.info.assert_called_once_with( + "Successfully completed chunk regeneration for new context" + ) diff --git a/backend/tests/apps/ai/management/commands/ai_run_rag_tool_test.py b/backend/tests/apps/ai/management/commands/ai_run_rag_tool_test.py new file mode 100644 index 0000000000..33017e3169 --- /dev/null +++ 
b/backend/tests/apps/ai/management/commands/ai_run_rag_tool_test.py @@ -0,0 +1,142 @@ +"""Tests for the ai_run_rag_tool Django management command.""" + +from unittest.mock import MagicMock, patch + +import pytest +from django.core.management.base import BaseCommand + +from apps.ai.management.commands.ai_run_rag_tool import Command + + +@pytest.fixture +def command(): + """Return a command instance.""" + return Command() + + +class TestAiRunRagToolCommand: + """Test suite for the ai_run_rag_tool command.""" + + def test_command_help_text(self, command): + """Test that the command has the correct help text.""" + assert command.help == "Test the RagTool functionality with a sample query" + + def test_command_inheritance(self, command): + """Test that the command inherits from BaseCommand.""" + assert isinstance(command, BaseCommand) + + def test_add_arguments(self, command): + """Test that the command adds the correct arguments.""" + parser = MagicMock() + command.add_arguments(parser) + + assert parser.add_argument.call_count == 6 + parser.add_argument.assert_any_call( + "--query", + type=str, + default="What is OWASP Foundation?", + help="Query to test the Rag tool", + ) + parser.add_argument.assert_any_call( + "--limit", + type=int, + default=5, # DEFAULT_CHUNKS_RETRIEVAL_LIMIT + help="Maximum number of results to retrieve", + ) + parser.add_argument.assert_any_call( + "--threshold", + type=float, + default=0.4, # DEFAULT_SIMILARITY_THRESHOLD + help="Similarity threshold (0.0 to 1.0)", + ) + parser.add_argument.assert_any_call( + "--content-types", + nargs="+", + default=None, + help="Content types to filter by (e.g., project chapter)", + ) + parser.add_argument.assert_any_call( + "--embedding-model", + type=str, + default="text-embedding-3-small", + help="OpenAI embedding model", + ) + parser.add_argument.assert_any_call( + "--chat-model", + type=str, + default="gpt-4o", + help="OpenAI chat model", + ) + + 
@patch("apps.ai.management.commands.ai_run_rag_tool.RagTool") + def test_handle_success(self, mock_rag_tool, command): + """Test successful command execution.""" + command.stdout = MagicMock() + mock_rag_instance = MagicMock() + mock_rag_instance.query.return_value = "Test answer" + mock_rag_tool.return_value = mock_rag_instance + + command.handle( + query="Test query", + limit=10, + threshold=0.8, + content_types=["project", "chapter"], + embedding_model="text-embedding-3-small", + chat_model="gpt-4o", + ) + + mock_rag_tool.assert_called_once_with( + chat_model="gpt-4o", embedding_model="text-embedding-3-small" + ) + mock_rag_instance.query.assert_called_once_with( + content_types=["project", "chapter"], + limit=10, + question="Test query", + similarity_threshold=0.8, + ) + command.stdout.write.assert_any_call("\nProcessing query...") + command.stdout.write.assert_any_call("\nAnswer: Test answer") + + @patch("apps.ai.management.commands.ai_run_rag_tool.RagTool") + def test_handle_initialization_error(self, mock_rag_tool, command): + """Test command when RagTool initialization fails.""" + command.stderr = MagicMock() + command.style = MagicMock() + mock_rag_tool.side_effect = ValueError("Initialization error") + + command.handle( + query="What is OWASP Foundation?", + limit=5, + threshold=0.5, + content_types=None, + embedding_model="text-embedding-3-small", + chat_model="gpt-4o", + ) + command.stderr.write.assert_called_once() + + @patch("apps.ai.management.commands.ai_run_rag_tool.RagTool") + def test_handle_with_default_values(self, mock_rag_tool, command): + """Test command with default argument values.""" + command.stdout = MagicMock() + mock_rag_instance = MagicMock() + mock_rag_instance.query.return_value = "Default answer" + mock_rag_tool.return_value = mock_rag_instance + + command.handle( + query="What is OWASP Foundation?", + limit=5, + threshold=0.5, + content_types=None, + embedding_model="text-embedding-3-small", + chat_model="gpt-4o", + ) + + 
mock_rag_tool.assert_called_once_with( + chat_model="gpt-4o", embedding_model="text-embedding-3-small" + ) + mock_rag_instance.query.assert_called_once_with( + content_types=None, + limit=5, # DEFAULT_CHUNKS_RETRIEVAL_LIMIT + question="What is OWASP Foundation?", + similarity_threshold=0.5, # DEFAULT_SIMILARITY_THRESHOLD + ) diff --git a/backend/tests/apps/ai/management/commands/ai_update_chapter_chunks_test.py b/backend/tests/apps/ai/management/commands/ai_update_chapter_chunks_test.py new file mode 100644 index 0000000000..7df18a89bd --- /dev/null +++ b/backend/tests/apps/ai/management/commands/ai_update_chapter_chunks_test.py @@ -0,0 +1,47 @@ +from unittest.mock import Mock, patch + +import pytest +from django.core.management.base import BaseCommand + +from apps.ai.management.commands.ai_update_chapter_chunks import Command + + +@pytest.fixture +def command(): + return Command() + + +@pytest.fixture +def mock_chapter(): + chapter = Mock() + chapter.id = 1 + chapter.key = "test-chapter" + return chapter + + +class TestAiCreateChapterChunksCommand: + def test_command_inheritance(self, command): + assert isinstance(command, BaseCommand) + + def test_model_class_property(self, command): + from apps.owasp.models.chapter import Chapter + + assert command.model_class == Chapter + + def test_entity_name_property(self, command): + assert command.entity_name == "chapter" + + def test_entity_name_plural_property(self, command): + assert command.entity_name_plural == "chapters" + + def test_key_field_name_property(self, command): + assert command.key_field_name == "key" + + def test_extract_content(self, command, mock_chapter): + with patch( + "apps.ai.management.commands.ai_update_chapter_chunks.extract_chapter_content" + ) as mock_extract: + mock_extract.return_value = ("prose content", "metadata content") + content = command.extract_content(mock_chapter) + assert content == ("prose content", "metadata content") + mock_extract.assert_called_once_with(mock_chapter) diff 
--git a/backend/tests/apps/ai/management/commands/ai_update_chapter_context_test.py b/backend/tests/apps/ai/management/commands/ai_update_chapter_context_test.py new file mode 100644 index 0000000000..792e2f2f02 --- /dev/null +++ b/backend/tests/apps/ai/management/commands/ai_update_chapter_context_test.py @@ -0,0 +1,60 @@ +"""Tests for the ai_create_chapter_context Django management command.""" + +from unittest.mock import Mock, patch + +import pytest + +from apps.ai.management.commands.ai_update_chapter_context import Command + + +@pytest.fixture +def command(): + return Command() + + +@pytest.fixture +def mock_chapter(): + chapter = Mock() + chapter.id = 1 + chapter.key = "test-chapter" + return chapter + + +class TestAiCreateChapterContextCommand: + def test_command_inheritance(self, command): + """Test that the command inherits from BaseContextCommand.""" + from apps.ai.common.base.context_command import BaseContextCommand + + assert isinstance(command, BaseContextCommand) + + def test_command_help_text(self, command): + """Test that the command has the correct help text.""" + assert command.help() == "Update context for OWASP chapter data" + + def test_model_class_property(self, command): + """Test the model_class property returns Chapter.""" + from apps.owasp.models.chapter import Chapter + + assert command.model_class == Chapter + + def test_entity_name_property(self, command): + """Test the entity_name property.""" + assert command.entity_name == "chapter" + + def test_entity_name_plural_property(self, command): + """Test the entity_name_plural property.""" + assert command.entity_name_plural == "chapters" + + def test_key_field_name_property(self, command): + """Test the key_field_name property.""" + assert command.key_field_name == "key" + + def test_extract_content(self, command, mock_chapter): + """Test the extract_content method.""" + with patch( + "apps.ai.management.commands.ai_update_chapter_context.extract_chapter_content" + ) as mock_extract: + 
mock_extract.return_value = ("prose content", "metadata content") + content = command.extract_content(mock_chapter) + assert content == ("prose content", "metadata content") + mock_extract.assert_called_once_with(mock_chapter) diff --git a/backend/tests/apps/ai/management/commands/ai_update_committee_chunks_test.py b/backend/tests/apps/ai/management/commands/ai_update_committee_chunks_test.py new file mode 100644 index 0000000000..e889fe5b11 --- /dev/null +++ b/backend/tests/apps/ai/management/commands/ai_update_committee_chunks_test.py @@ -0,0 +1,149 @@ +"""Tests for the ai_create_committee_chunks command.""" + +from unittest.mock import Mock, patch + +import pytest +from django.core.management.base import BaseCommand + +from apps.ai.management.commands.ai_update_committee_chunks import Command + + +@pytest.fixture +def command(): + """Return a command instance.""" + return Command() + + +@pytest.fixture +def mock_committee(): + """Return a mock Committee instance.""" + committee = Mock() + committee.id = 1 + committee.key = "test-committee" + committee.name = "Test Committee" + committee.description = "Test committee description" + committee.is_active = True + return committee + + +class TestAiCreateCommitteeChunksCommand: + """Test suite for the ai_create_committee_chunks command.""" + + def test_command_inheritance(self, command): + """Test that the command inherits from BaseCommand.""" + assert isinstance(command, BaseCommand) + + def test_model_class_method(self, command): + """Test the model_class method returns Committee.""" + from apps.owasp.models.committee import Committee + + assert command.model_class == Committee + + def test_entity_name_method(self, command): + """Test the entity_name method.""" + assert command.entity_name == "committee" + + def test_entity_name_plural_method(self, command): + """Test the entity_name_plural method.""" + assert command.entity_name_plural == "committees" + + def test_key_field_name_method(self, command): + """Test the 
key_field_name method.""" + assert command.key_field_name == "key" + + def test_extract_content_method(self, command, mock_committee): + """Test the extract_content method.""" + with patch( + "apps.ai.management.commands.ai_update_committee_chunks.extract_committee_content" + ) as mock_extract: + mock_extract.return_value = ("prose content", "metadata content") + content = command.extract_content(mock_committee) + assert content == ("prose content", "metadata content") + mock_extract.assert_called_once_with(mock_committee) + + def test_get_base_queryset_calls_super(self, command): + """Test that get_base_queryset calls the parent method.""" + with patch( + "apps.ai.common.base.chunk_command.BaseChunkCommand.get_base_queryset" + ) as mock_super: + mock_super.return_value = "base_queryset" + result = command.get_base_queryset() + assert result == "base_queryset" + mock_super.assert_called_once() + + def test_get_default_queryset_filters_active(self, command): + """Test that get_default_queryset filters for active committees.""" + with patch.object(command, "get_base_queryset") as mock_get_base: + mock_queryset = Mock() + mock_get_base.return_value = mock_queryset + mock_queryset.filter.return_value = "filtered_queryset" + + result = command.get_default_queryset() + + assert result == "filtered_queryset" + mock_queryset.filter.assert_called_once_with(is_active=True) + + def test_add_arguments_calls_super(self, command): + """Test that add_arguments calls the parent method.""" + mock_parser = Mock() + with patch.object(command, "add_common_arguments") as mock_add_common: + command.add_arguments(mock_parser) + mock_add_common.assert_called_once_with(mock_parser) + + def test_get_queryset_with_committee_key(self, command): + """Test get_queryset with committee key option.""" + with patch.object(command, "get_base_queryset") as mock_get_base: + mock_queryset = Mock() + mock_get_base.return_value = mock_queryset + mock_queryset.filter.return_value = "filtered_queryset" + + 
options = {"committee_key": "specific-committee"} + result = command.get_queryset(options) + + assert result == "filtered_queryset" + mock_queryset.filter.assert_called_once_with(key="specific-committee") + + def test_get_queryset_with_all_option(self, command): + """Test get_queryset with all option.""" + with patch.object(command, "get_base_queryset") as mock_get_base: + mock_queryset = Mock() + mock_get_base.return_value = mock_queryset + + options = {"all": True} + result = command.get_queryset(options) + + assert result == mock_queryset + + def test_get_queryset_default_behavior(self, command): + """Test get_queryset with default behavior.""" + with patch.object(command, "get_default_queryset") as mock_get_default: + mock_get_default.return_value = "default_queryset" + + options = {} + result = command.get_queryset(options) + + assert result == "default_queryset" + + def test_get_entity_key_returns_key(self, command, mock_committee): + """Test get_entity_key returns the committee key.""" + result = command.get_entity_key(mock_committee) + assert result == "test-committee" + + def test_get_entity_key_fallback_to_pk(self, command): + """Test get_entity_key falls back to pk when key field doesn't exist.""" + mock_committee = Mock() + mock_committee.pk = 123 + + if hasattr(mock_committee, "key"): + delattr(mock_committee, "key") + + result = command.get_entity_key(mock_committee) + assert result == "123" + + def test_source_name_default(self, command): + """Test default source name.""" + assert command.source_name() == "owasp_committee" + + def test_help_method(self, command): + """Test the help method.""" + assert command.help() == "Create or update chunks for OWASP committee data" diff --git a/backend/tests/apps/ai/management/commands/ai_update_committee_context_test.py b/backend/tests/apps/ai/management/commands/ai_update_committee_context_test.py new file mode 100644 index 0000000000..6c1940dca1 --- /dev/null +++ 
b/backend/tests/apps/ai/management/commands/ai_update_committee_context_test.py @@ -0,0 +1,230 @@ +"""Tests for the ai_create_committee_context command.""" + +from unittest.mock import Mock, patch + +import pytest + +from apps.ai.management.commands.ai_update_committee_context import Command + + +@pytest.fixture +def command(): + """Return a command instance.""" + return Command() + + +@pytest.fixture +def mock_committee(): + """Return a mock Committee instance.""" + committee = Mock() + committee.id = 1 + committee.key = "test-committee" + committee.name = "Test Committee" + committee.description = "Test committee description" + committee.is_active = True + return committee + + +class TestAiCreateCommitteeContextCommand: + """Test suite for the ai_create_committee_context command.""" + + def test_command_inheritance(self, command): + """Test that the command inherits from BaseContextCommand.""" + from apps.ai.common.base.context_command import BaseContextCommand + + assert isinstance(command, BaseContextCommand) + + def test_command_help_text(self, command): + """Test that the command has the correct help text.""" + assert command.help() == "Update context for OWASP committee data" + + def test_model_class_method(self, command): + """Test the model_class method returns Committee.""" + from apps.owasp.models.committee import Committee + + assert command.model_class == Committee + + def test_entity_name_method(self, command): + """Test the entity_name method.""" + assert command.entity_name == "committee" + + def test_entity_name_plural_method(self, command): + """Test the entity_name_plural method.""" + assert command.entity_name_plural == "committees" + + def test_key_field_name_method(self, command): + """Test the key_field_name method.""" + assert command.key_field_name == "key" + + def test_extract_content_method(self, command, mock_committee): + """Test the extract_content method.""" + with patch( + 
"apps.ai.management.commands.ai_update_committee_context.extract_committee_content" + ) as mock_extract: + mock_extract.return_value = ("prose content", "metadata content") + content = command.extract_content(mock_committee) + assert content == ("prose content", "metadata content") + mock_extract.assert_called_once_with(mock_committee) + + def test_get_base_queryset_calls_super(self, command): + """Test that get_base_queryset calls the parent method.""" + with patch( + "apps.ai.common.base.context_command.BaseContextCommand.get_base_queryset" + ) as mock_super: + mock_super.return_value = "base_queryset" + result = command.get_base_queryset() + assert result == "base_queryset" + mock_super.assert_called_once() + + def test_get_default_queryset_filters_active(self, command): + """Test that get_default_queryset filters for active committees.""" + with patch.object(command, "get_base_queryset") as mock_get_base: + mock_queryset = Mock() + mock_get_base.return_value = mock_queryset + mock_queryset.filter.return_value = "filtered_queryset" + + result = command.get_default_queryset() + + assert result == "filtered_queryset" + mock_queryset.filter.assert_called_once_with(is_active=True) + + def test_add_arguments_calls_super(self, command): + """Test that add_arguments calls the parent method.""" + mock_parser = Mock() + with patch.object(command, "add_common_arguments") as mock_add_common: + command.add_arguments(mock_parser) + mock_add_common.assert_called_once_with(mock_parser) + + def test_get_queryset_with_committee_key(self, command): + """Test get_queryset with committee key option.""" + with patch.object(command, "get_base_queryset") as mock_get_base: + mock_queryset = Mock() + mock_get_base.return_value = mock_queryset + mock_queryset.filter.return_value = "filtered_queryset" + + options = {"committee_key": "specific-committee"} + result = command.get_queryset(options) + + assert result == "filtered_queryset" + 
mock_queryset.filter.assert_called_once_with(key="specific-committee") + + def test_get_queryset_with_all_option(self, command): + """Test get_queryset with all option.""" + with patch.object(command, "get_base_queryset") as mock_get_base: + mock_queryset = Mock() + mock_get_base.return_value = mock_queryset + + options = {"all": True} + result = command.get_queryset(options) + + assert result == mock_queryset + + def test_get_queryset_default_behavior(self, command): + """Test get_queryset with default behavior.""" + with patch.object(command, "get_default_queryset") as mock_get_default: + mock_get_default.return_value = "default_queryset" + + options = {} + result = command.get_queryset(options) + + assert result == "default_queryset" + + def test_get_entity_key_returns_key(self, command, mock_committee): + """Test get_entity_key returns the committee key.""" + result = command.get_entity_key(mock_committee) + assert result == "test-committee" + + def test_get_entity_key_fallback_to_pk(self, command): + """Test get_entity_key falls back to pk when key field doesn't exist.""" + mock_committee = Mock() + mock_committee.pk = 123 + + if hasattr(mock_committee, "key"): + delattr(mock_committee, "key") + + result = command.get_entity_key(mock_committee) + assert result == "123" + + def test_source_name_default(self, command): + """Test default source name.""" + assert command.source_name() == "owasp_committee" + + def test_process_context_batch_success(self, command, mock_committee): + """Test successful context batch processing.""" + with patch("apps.ai.common.base.context_command.Context") as mock_context_class: + mock_context_class.update_data.return_value = True + + with patch.object(command, "extract_content") as mock_extract: + mock_extract.return_value = ("Content", "Metadata") + + with patch.object(command, "get_entity_key") as mock_get_key: + mock_get_key.return_value = "test-committee" + + with patch.object(command.stdout, "write") as mock_write: + result = 
command.process_context_batch([mock_committee]) + + assert result == 1 + mock_context_class.update_data.assert_called_once_with( + content="Metadata\n\nContent", + entity=mock_committee, + source="owasp_committee", + ) + mock_write.assert_called_once_with("Created context for test-committee") + + def test_process_context_batch_empty_content(self, command, mock_committee): + """Test context batch processing with empty content.""" + with patch.object(command, "extract_content") as mock_extract: + mock_extract.return_value = ("", "") + + with patch.object(command, "get_entity_key") as mock_get_key: + mock_get_key.return_value = "test-committee" + + with patch.object(command.stdout, "write") as mock_write: + result = command.process_context_batch([mock_committee]) + + assert result == 0 + mock_write.assert_called_once_with("No content for committee test-committee") + + def test_process_context_batch_create_failure(self, command, mock_committee): + """Test context batch processing when Context.update_data fails.""" + with patch("apps.ai.common.base.context_command.Context") as mock_context_class: + mock_context_class.update_data.return_value = False + + with patch.object(command, "extract_content") as mock_extract: + mock_extract.return_value = ("Content", "Metadata") + + with patch.object(command, "get_entity_key") as mock_get_key: + mock_get_key.return_value = "test-committee" + + with ( + patch.object(command.stdout, "write") as mock_write, + patch.object(command.style, "ERROR") as mock_error, + ): + mock_error.return_value = "ERROR: Failed" + + result = command.process_context_batch([mock_committee]) + + assert result == 0 + mock_error.assert_called_once_with( + "Failed to create context for test-committee" + ) + mock_write.assert_called_once_with("ERROR: Failed") + + def test_handle_calls_batch_processing(self, command): + """Test that handle method calls batch processing.""" + mock_queryset = Mock() + mock_queryset.count.return_value = 2 + + with 
patch.object(command, "get_queryset") as mock_get_queryset: + mock_get_queryset.return_value = mock_queryset + + with patch.object(command, "handle_batch_processing") as mock_batch_processing: + options = {"batch_size": 50} + command.handle(**options) + + mock_get_queryset.assert_called_once_with(options) + mock_batch_processing.assert_called_once_with( + queryset=mock_queryset, + batch_size=50, + process_batch_func=command.process_context_batch, + ) diff --git a/backend/tests/apps/ai/management/commands/ai_update_event_chunks_test.py b/backend/tests/apps/ai/management/commands/ai_update_event_chunks_test.py new file mode 100644 index 0000000000..298b50b185 --- /dev/null +++ b/backend/tests/apps/ai/management/commands/ai_update_event_chunks_test.py @@ -0,0 +1,78 @@ +"""Tests for the ai_create_event_chunks Django management command.""" + +from unittest.mock import Mock, patch + +import pytest +from django.core.management.base import BaseCommand + +from apps.ai.management.commands.ai_update_event_chunks import Command + + +@pytest.fixture +def command(): + """Return a command instance.""" + return Command() + + +@pytest.fixture +def mock_event(): + """Return a mock Event instance.""" + event = Mock() + event.id = 1 + event.key = "test-event" + return event + + +class TestAiCreateEventChunksCommand: + """Test suite for the ai_create_event_chunks command.""" + + def test_command_inheritance(self, command): + """Test that the command inherits from BaseCommand.""" + assert isinstance(command, BaseCommand) + + def test_model_class_property(self, command): + """Test the model_class property returns Event.""" + from apps.owasp.models.event import Event + + assert command.model_class == Event + + def test_entity_name_property(self, command): + """Test the entity_name property.""" + assert command.entity_name == "event" + + def test_entity_name_plural_property(self, command): + """Test the entity_name_plural property.""" + assert command.entity_name_plural == "events" + + def 
test_key_field_name_property(self, command): + """Test the key_field_name property.""" + assert command.key_field_name == "key" + + def test_extract_content(self, command, mock_event): + """Test content extraction from event.""" + with patch( + "apps.ai.management.commands.ai_update_event_chunks.extract_event_content" + ) as mock_extract: + mock_extract.return_value = ("prose content", "metadata content") + content = command.extract_content(mock_event) + assert content == ("prose content", "metadata content") + mock_extract.assert_called_once_with(mock_event) + + def test_get_base_queryset(self, command): + """Test get_base_queryset calls super().get_base_queryset().""" + with patch( + "apps.ai.common.base.ai_command.BaseAICommand.get_base_queryset", + return_value="base_qs", + ) as mock_super: + result = command.get_base_queryset() + assert result == "base_qs" + mock_super.assert_called_once() + + def test_get_default_queryset(self, command): + """Test that the default queryset returns upcoming events.""" + with patch("apps.owasp.models.event.Event.upcoming_events") as mock_upcoming: + mock_queryset = Mock() + mock_upcoming.return_value = mock_queryset + result = command.get_default_queryset() + assert result == mock_queryset + mock_upcoming.assert_called_once() diff --git a/backend/tests/apps/ai/management/commands/ai_update_event_context_test.py b/backend/tests/apps/ai/management/commands/ai_update_event_context_test.py new file mode 100644 index 0000000000..8f4d2d12ce --- /dev/null +++ b/backend/tests/apps/ai/management/commands/ai_update_event_context_test.py @@ -0,0 +1,77 @@ +"""Tests for the ai_create_event_context Django management command.""" + +from unittest.mock import Mock, patch + +import pytest + +from apps.ai.management.commands.ai_update_event_context import Command + + +@pytest.fixture +def command(): + return Command() + + +@pytest.fixture +def mock_event(): + event = Mock() + event.id = 1 + event.key = "test-event" + return event + + +class 
TestAiCreateEventContextCommand: + def test_command_inheritance(self, command): + """Test that the command inherits from BaseContextCommand.""" + from apps.ai.common.base.context_command import BaseContextCommand + + assert isinstance(command, BaseContextCommand) + + def test_command_help_text(self, command): + """Test that the command has the correct help text.""" + assert command.help() == "Update context for OWASP event data" + + def test_model_class_property(self, command): + """Test the model_class property returns Event.""" + from apps.owasp.models.event import Event + + assert command.model_class == Event + + def test_entity_name_property(self, command): + """Test the entity_name property.""" + assert command.entity_name == "event" + + def test_entity_name_plural_property(self, command): + """Test the entity_name_plural property.""" + assert command.entity_name_plural == "events" + + def test_key_field_name_property(self, command): + """Test the key_field_name property.""" + assert command.key_field_name == "key" + + def test_extract_content(self, command, mock_event): + """Test the extract_content method.""" + with patch( + "apps.ai.management.commands.ai_update_event_context.extract_event_content" + ) as mock_extract: + mock_extract.return_value = ("prose content", "metadata content") + content = command.extract_content(mock_event) + assert content == ("prose content", "metadata content") + mock_extract.assert_called_once_with(mock_event) + + def test_get_default_queryset(self, command): + """Test that the default queryset returns upcoming events.""" + with patch("apps.owasp.models.event.Event.upcoming_events") as mock_upcoming: + mock_queryset = Mock() + mock_upcoming.return_value = mock_queryset + result = command.get_default_queryset() + assert result == mock_queryset + mock_upcoming.assert_called_once() + + def test_get_base_queryset(self, command): + """Test the get_base_queryset method.""" + with patch.object(command.__class__.__bases__[0], 
"get_base_queryset") as mock_super: + mock_super.return_value = Mock() + result = command.get_base_queryset() + mock_super.assert_called_once() + assert result == mock_super.return_value diff --git a/backend/tests/apps/ai/management/commands/ai_update_project_chunks_test.py b/backend/tests/apps/ai/management/commands/ai_update_project_chunks_test.py new file mode 100644 index 0000000000..85f889ddaf --- /dev/null +++ b/backend/tests/apps/ai/management/commands/ai_update_project_chunks_test.py @@ -0,0 +1,57 @@ +from unittest.mock import Mock, patch + +import pytest +from django.core.management.base import BaseCommand + +from apps.ai.management.commands.ai_update_project_chunks import Command + + +@pytest.fixture +def command(): + return Command() + + +@pytest.fixture +def mock_project(): + project = Mock() + project.id = 1 + project.key = "test-project" + return project + + +class TestAiCreateProjectChunksCommand: + def test_command_inheritance(self, command): + assert isinstance(command, BaseCommand) + + def test_model_class_property(self, command): + from apps.owasp.models.project import Project + + assert command.model_class == Project + + def test_entity_name_property(self, command): + assert command.entity_name == "project" + + def test_entity_name_plural_property(self, command): + assert command.entity_name_plural == "projects" + + def test_key_field_name_property(self, command): + assert command.key_field_name == "key" + + def test_extract_content(self, command, mock_project): + with patch( + "apps.ai.management.commands.ai_update_project_chunks.extract_project_content" + ) as mock_extract: + mock_extract.return_value = ("prose content", "metadata content") + content = command.extract_content(mock_project) + assert content == ("prose content", "metadata content") + mock_extract.assert_called_once_with(mock_project) + + def test_get_base_queryset_calls_super(self, command): + """Test that get_base_queryset calls the parent method.""" + with patch( + 
"apps.ai.common.base.chunk_command.BaseChunkCommand.get_base_queryset" + ) as mock_super: + mock_super.return_value = "base_queryset" + result = command.get_base_queryset() + assert result == "base_queryset" + mock_super.assert_called_once() diff --git a/backend/tests/apps/ai/management/commands/ai_update_project_context_test.py b/backend/tests/apps/ai/management/commands/ai_update_project_context_test.py new file mode 100644 index 0000000000..82df257634 --- /dev/null +++ b/backend/tests/apps/ai/management/commands/ai_update_project_context_test.py @@ -0,0 +1,66 @@ +from unittest.mock import Mock, patch + +import pytest + +from apps.ai.management.commands.ai_update_project_context import Command + + +@pytest.fixture +def command(): + return Command() + + +@pytest.fixture +def mock_project(): + project = Mock() + project.id = 1 + project.key = "test-project" + return project + + +class TestAiCreateProjectContextCommand: + def test_command_inheritance(self, command): + """Test that the command inherits from BaseContextCommand.""" + from apps.ai.common.base.context_command import BaseContextCommand + + assert isinstance(command, BaseContextCommand) + + def test_command_help_text(self, command): + """Test that the command has the correct help text.""" + assert command.help() == "Update context for OWASP project data" + + def test_model_class_property(self, command): + """Test the model_class property returns Project.""" + from apps.owasp.models.project import Project + + assert command.model_class == Project + + def test_entity_name_property(self, command): + """Test the entity_name property.""" + assert command.entity_name == "project" + + def test_entity_name_plural_property(self, command): + """Test the entity_name_plural property.""" + assert command.entity_name_plural == "projects" + + def test_key_field_name_property(self, command): + """Test the key_field_name property.""" + assert command.key_field_name == "key" + + def test_extract_content(self, command, 
mock_project): + """Test the extract_content method.""" + with patch( + "apps.ai.management.commands.ai_update_project_context.extract_project_content" + ) as mock_extract: + mock_extract.return_value = ("prose content", "metadata content") + content = command.extract_content(mock_project) + assert content == ("prose content", "metadata content") + mock_extract.assert_called_once_with(mock_project) + + def test_get_base_queryset(self, command): + """Test the get_base_queryset method.""" + with patch.object(command.__class__.__bases__[0], "get_base_queryset") as mock_super: + mock_super.return_value = Mock() + result = command.get_base_queryset() + mock_super.assert_called_once() + assert result == mock_super.return_value diff --git a/backend/tests/apps/ai/management/commands/ai_update_slack_message_chunks_test.py b/backend/tests/apps/ai/management/commands/ai_update_slack_message_chunks_test.py new file mode 100644 index 0000000000..22c9db9772 --- /dev/null +++ b/backend/tests/apps/ai/management/commands/ai_update_slack_message_chunks_test.py @@ -0,0 +1,83 @@ +from unittest.mock import Mock, patch + +import pytest +from django.core.management.base import BaseCommand + +from apps.ai.management.commands.ai_update_slack_message_chunks import Command + + +@pytest.fixture +def command(): + return Command() + + +@pytest.fixture +def mock_message(): + message = Mock() + message.id = 1 + message.slack_message_id = "test-message-id" + return message + + +class TestAiCreateSlackMessageChunksCommand: + def test_command_inheritance(self, command): + assert isinstance(command, BaseCommand) + + def test_model_class_property(self, command): + from apps.slack.models.message import Message + + assert command.model_class == Message + + def test_entity_name_property(self, command): + assert command.entity_name == "message" + + def test_entity_name_plural_property(self, command): + assert command.entity_name_plural == "messages" + + def test_key_field_name_property(self, command): + 
assert command.key_field_name == "slack_message_id" + + def test_source_name_property(self, command): + """Test the source_name property.""" + assert command.source_name() == "slack_message" + + def test_extract_content(self, command, mock_message): + """Test the extract_content method.""" + mock_message.cleaned_text = "Test message content" + content = command.extract_content(mock_message) + assert content == ("Test message content", "") + + def test_extract_content_none_cleaned_text(self, command, mock_message): + """Test the extract_content method with None cleaned_text.""" + mock_message.cleaned_text = None + content = command.extract_content(mock_message) + assert content == ("", "") + + def test_get_default_queryset(self, command): + """Test the get_default_queryset method.""" + with patch.object(command, "get_base_queryset") as mock_base: + mock_base.return_value = Mock() + result = command.get_default_queryset() + mock_base.assert_called_once() + assert result == mock_base.return_value + + def test_add_arguments(self, command): + """Test the add_arguments method.""" + parser = Mock() + command.add_arguments(parser) + + assert parser.add_argument.call_count == 3 + calls = parser.add_argument.call_args_list + + assert calls[0][0] == ("--message-key",) + assert calls[0][1]["type"] is str + assert "Process only the message with this key" in calls[0][1]["help"] + + assert calls[1][0] == ("--all",) + assert calls[1][1]["action"] == "store_true" + assert "Process all the messages" in calls[1][1]["help"] + + assert calls[2][0] == ("--batch-size",) + assert calls[2][1]["type"] is int + assert calls[2][1]["default"] == 100 + assert "Number of messages to process in each batch" in calls[2][1]["help"] diff --git a/backend/tests/apps/ai/management/commands/ai_update_slack_message_context_test.py b/backend/tests/apps/ai/management/commands/ai_update_slack_message_context_test.py new file mode 100644 index 0000000000..2d6ea5fec3 --- /dev/null +++ 
b/backend/tests/apps/ai/management/commands/ai_update_slack_message_context_test.py @@ -0,0 +1,104 @@ +from unittest.mock import Mock, patch + +import pytest + +from apps.ai.management.commands.ai_update_slack_message_context import Command + + +@pytest.fixture +def command(): + return Command() + + +@pytest.fixture +def mock_message(): + message = Mock() + message.id = 1 + message.slack_message_id = "test-message-id" + return message + + +class TestAiCreateSlackMessageContextCommand: + def test_command_inheritance(self, command): + """Test that the command inherits from BaseContextCommand.""" + from apps.ai.common.base.context_command import BaseContextCommand + + assert isinstance(command, BaseContextCommand) + + def test_model_class_property(self, command): + """Test the model_class property returns Message.""" + from apps.slack.models.message import Message + + assert command.model_class == Message + + def test_entity_name_property(self, command): + """Test the entity_name property.""" + assert command.entity_name == "message" + + def test_entity_name_plural_property(self, command): + """Test the entity_name_plural property.""" + assert command.entity_name_plural == "messages" + + def test_key_field_name_property(self, command): + """Test the key_field_name property.""" + assert command.key_field_name == "slack_message_id" + + def test_source_name_property(self, command): + """Test the source_name property.""" + assert command.source_name() == "slack_message" + + def test_extract_content(self, command, mock_message): + """Test the extract_content method.""" + mock_message.cleaned_text = "Test message content" + content = command.extract_content(mock_message) + assert content == ("Test message content", "") + + def test_extract_content_none_cleaned_text(self, command, mock_message): + """Test the extract_content method with None cleaned_text.""" + mock_message.cleaned_text = None + content = command.extract_content(mock_message) + assert content == ("", "") + + 
def test_get_default_queryset(self, command): + """Test the get_default_queryset method.""" + with patch.object(command, "get_base_queryset") as mock_base: + mock_base.return_value = Mock() + result = command.get_default_queryset() + mock_base.assert_called_once() + assert result == mock_base.return_value + + def test_add_arguments(self, command): + """Test the add_arguments method.""" + parser = Mock() + command.add_arguments(parser) + + assert parser.add_argument.call_count == 6 + calls = parser.add_argument.call_args_list + + # First 3 calls are from parent class (BaseAICommand) + assert calls[0][0] == ("--message-key",) + assert calls[0][1]["type"] is str + assert "Process only the message with this key" in calls[0][1]["help"] + + assert calls[1][0] == ("--all",) + assert calls[1][1]["action"] == "store_true" + assert "Process all the messages" in calls[1][1]["help"] + + assert calls[2][0] == ("--batch-size",) + assert calls[2][1]["type"] is int + assert calls[2][1]["default"] == 50 # Default from parent class + assert "Number of messages to process in each batch" in calls[2][1]["help"] + + # Next 3 calls are from the command itself (duplicates with different defaults) + assert calls[3][0] == ("--message-key",) + assert calls[3][1]["type"] is str + assert "Process only the message with this key" in calls[3][1]["help"] + + assert calls[4][0] == ("--all",) + assert calls[4][1]["action"] == "store_true" + assert "Process all the messages" in calls[4][1]["help"] + + assert calls[5][0] == ("--batch-size",) + assert calls[5][1]["type"] is int + assert calls[5][1]["default"] == 100 # Overridden default from command + assert "Number of messages to process in each batch" in calls[5][1]["help"] diff --git a/backend/tests/apps/ai/models/chunk_test.py b/backend/tests/apps/ai/models/chunk_test.py index dca223f800..d3c1fc61c9 100644 --- a/backend/tests/apps/ai/models/chunk_test.py +++ b/backend/tests/apps/ai/models/chunk_test.py @@ -1,39 +1,32 @@ from unittest.mock import 
Mock, patch -from django.contrib.contenttypes.models import ContentType -from django.db import models +import pytest from apps.ai.models.chunk import Chunk -from apps.slack.models.message import Message +from apps.ai.models.context import Context -def create_model_mock(model_class): - mock = Mock(spec=model_class) - mock._state = Mock() - mock.pk = 1 +@pytest.fixture +def mock_context(): + mock = Mock(spec=Context) mock.id = 1 + mock.entity_type = Mock() + mock.entity_id = 1 return mock class TestChunkModel: def test_str_method(self): - mock_message = create_model_mock(Message) - mock_message.name = "Test Message" - - mock_content_type = Mock(spec=ContentType) - mock_content_type.model = "message" - - with ( - patch.object(Chunk, "content_type", mock_content_type), - patch.object(Chunk, "content_object", mock_message), - ): - chunk = Chunk() - chunk.id = 1 - chunk.text = "This is a test chunk with some content that should be displayed" - - result = str(chunk) - assert "Chunk 1 for message Test Message:" in result - assert "This is a test chunk with some content that" in result + mock_context = Mock(spec=Context) + mock_context.__str__ = Mock(return_value="Context 1 for message Test Message: ...") + mock_context._state = Mock() + chunk = Chunk() + chunk.id = 1 + chunk.text = "This is a test chunk with some content that should be displayed" + chunk.context = mock_context + result = str(chunk) + assert "Chunk 1 for Context 1 for message Test Message:" in result + assert "This is a test chunk with some content that" in result def test_bulk_save_with_chunks(self): mock_chunks = [Mock(), Mock(), Mock()] @@ -60,118 +53,105 @@ def test_split_text(self): assert all(isinstance(chunk, str) for chunk in result) @patch("apps.ai.models.chunk.Chunk.save") - @patch("apps.ai.models.chunk.Chunk.__init__") - def test_update_data_new_chunk(self, mock_init, mock_save, mocker): - mock_init.return_value = None - - mock_message = create_model_mock(Message) + def 
test_update_data_save_with_context(self, mock_save): text = "Test chunk content" embedding = [0.1, 0.2, 0.3] - - mock_content_type = Mock(spec=ContentType) - mock_get_for_model = mocker.patch( - "django.contrib.contenttypes.models.ContentType.objects.get_for_model", - return_value=mock_content_type, - ) - - mock_filter = mocker.patch( - "apps.ai.models.chunk.Chunk.objects.filter", - return_value=Mock(exists=Mock(return_value=False)), - ) - - result = Chunk.update_data( - text=text, content_object=mock_message, embedding=embedding, save=True - ) - - mock_get_for_model.assert_called_once_with(mock_message) - mock_filter.assert_called_once_with( - content_type=mock_content_type, object_id=mock_message.id, text=text - ) - mock_init.assert_called_once_with( - content_type=mock_content_type, - object_id=mock_message.id, - text=text, - embedding=embedding, - ) - mock_save.assert_called_once() - - assert result is not None - assert isinstance(result, Chunk) - - def test_update_data_existing_chunk(self, mocker): - mock_message = create_model_mock(Message) - text = "Existing chunk content" + mock_context = Mock(spec=Context) + + with patch("apps.ai.models.chunk.Chunk") as mock_chunk: + chunk_instance = Mock() + chunk_instance.context_id = 123 + mock_chunk.return_value = chunk_instance + mock_chunk.objects.filter.return_value.exists.return_value = False + + result = Chunk.update_data( + text=text, embedding=embedding, context=mock_context, save=True + ) + + mock_chunk.assert_called_once_with( + text=text, embedding=embedding, context=mock_context + ) + chunk_instance.save.assert_called_once() + assert result is chunk_instance + + def test_update_data_creates_new_chunk_and_saves(self, mock_context): + """Test that a new chunk is created and saved.""" + text = "New unique chunk content" embedding = [0.1, 0.2, 0.3] - mock_content_type = Mock(spec=ContentType) - mock_get_for_model = mocker.patch( - "django.contrib.contenttypes.models.ContentType.objects.get_for_model", - 
return_value=mock_content_type, - ) - - mock_filter = mocker.patch( - "apps.ai.models.chunk.Chunk.objects.filter", - return_value=Mock(exists=Mock(return_value=True)), - ) + with patch("apps.ai.models.chunk.Chunk") as mock_chunk_class: + mock_chunk_class.objects.filter.return_value.exists.return_value = False + mock_instance = Mock(spec=Chunk) + mock_chunk_class.return_value = mock_instance + + result = Chunk.update_data( + text=text, embedding=embedding, context=mock_context, save=True + ) + + mock_chunk_class.objects.filter.assert_called_once_with( + context__entity_type=mock_context.entity_type, + context__entity_id=mock_context.entity_id, + text=text, + ) + mock_chunk_class.assert_called_once_with( + text=text, embedding=embedding, context=mock_context + ) + mock_instance.save.assert_called_once() + assert result is mock_instance + + def test_update_data_creates_new_chunk_no_save(self, mock_context): + """Test that a new chunk is created but NOT saved when save=False.""" + text = "New unique chunk content" + embedding = [0.1, 0.2, 0.3] - result = Chunk.update_data( - text=text, content_object=mock_message, embedding=embedding, save=True - ) + with patch("apps.ai.models.chunk.Chunk") as mock_chunk_class: + mock_chunk_class.objects.filter.return_value.exists.return_value = False + mock_instance = Mock(spec=Chunk) + mock_chunk_class.return_value = mock_instance + + result = Chunk.update_data( + text=text, embedding=embedding, context=mock_context, save=False + ) + + mock_chunk_class.objects.filter.assert_called_once_with( + context__entity_type=mock_context.entity_type, + context__entity_id=mock_context.entity_id, + text=text, + ) + mock_chunk_class.assert_called_once_with( + text=text, embedding=embedding, context=mock_context + ) + mock_instance.save.assert_not_called() + assert result is mock_instance + + def test_update_data_returns_none_if_chunk_already_exists(self, mock_context): + """Test that update_data returns None when a chunk with the same text.""" + 
text = "Existing chunk content" + embedding = [0.1, 0.2, 0.3] - mock_get_for_model.assert_called_once_with(mock_message) - mock_filter.assert_called_once_with( - content_type=mock_content_type, object_id=mock_message.id, text=text - ) - assert result is None + with patch("apps.ai.models.chunk.Chunk") as mock_chunk_class: + mock_chunk_class.objects.filter.return_value.exists.return_value = True - @patch("apps.ai.models.chunk.Chunk.save") - @patch("apps.ai.models.chunk.Chunk.__init__") - def test_update_data_no_save(self, mock_init, mock_save, mocker): - mock_init.return_value = None + result = Chunk.update_data( + text=text, embedding=embedding, context=mock_context, save=True + ) - mock_message = create_model_mock(Message) - text = "Test chunk content" - embedding = [0.1, 0.2, 0.3] + mock_chunk_class.objects.filter.assert_called_once_with( + context__entity_type=mock_context.entity_type, + context__entity_id=mock_context.entity_id, + text=text, + ) - mock_content_type = Mock(spec=ContentType) - mock_get_for_model = mocker.patch( - "django.contrib.contenttypes.models.ContentType.objects.get_for_model", - return_value=mock_content_type, - ) - - mock_filter = mocker.patch( - "apps.ai.models.chunk.Chunk.objects.filter", - return_value=Mock(exists=Mock(return_value=False)), - ) - - result = Chunk.update_data( - text=text, content_object=mock_message, embedding=embedding, save=False - ) - - mock_get_for_model.assert_called_once_with(mock_message) - mock_filter.assert_called_once_with( - content_type=mock_content_type, object_id=mock_message.id, text=text - ) - mock_init.assert_called_once_with( - content_type=mock_content_type, - object_id=mock_message.id, - text=text, - embedding=embedding, - ) - mock_save.assert_not_called() - - assert result is not None - assert isinstance(result, Chunk) + mock_chunk_class.assert_not_called() + assert result is None def test_meta_class_attributes(self): assert Chunk._meta.db_table == "ai_chunks" assert Chunk._meta.verbose_name == 
"Chunk" - assert ("content_type", "object_id", "text") in Chunk._meta.unique_together + assert ("context", "text") in Chunk._meta.unique_together - def test_generic_foreign_key_relationship(self): - content_type_field = Chunk._meta.get_field("content_type") - object_id_field = Chunk._meta.get_field("object_id") + def test_context_relationship(self): + context_field = Chunk._meta.get_field("context") + from apps.ai.models.context import Context - assert isinstance(content_type_field, models.ForeignKey) - assert content_type_field.remote_field.model == ContentType - assert isinstance(object_id_field, models.PositiveIntegerField) + assert context_field.related_model == Context diff --git a/backend/tests/apps/ai/models/context_test.py b/backend/tests/apps/ai/models/context_test.py new file mode 100644 index 0000000000..d98f5061e3 --- /dev/null +++ b/backend/tests/apps/ai/models/context_test.py @@ -0,0 +1,376 @@ +"""Unit tests for AI app context model.""" + +from unittest.mock import Mock, PropertyMock, patch + +import pytest +from django.contrib.contenttypes.models import ContentType + +from apps.ai.models.context import Context +from apps.common.models import TimestampedModel + + +def create_model_mock(model_class): + mock = Mock(spec=model_class) + mock._state = Mock() + mock.pk = 1 + mock.id = 1 + mock.chunks = Mock() + mock.chunks.count.return_value = 0 + return mock + + +class TestContextModel: + def test_meta_class_attributes(self): + assert Context._meta.db_table == "ai_contexts" + assert Context._meta.verbose_name == "Context" + + def test_content_field_properties(self): + field = Context._meta.get_field("content") + assert field.verbose_name == "Generated Text" + assert field.__class__.__name__ == "TextField" + + def test_content_type_field_properties(self): + field = Context._meta.get_field("entity_type") + assert field.null is False + assert field.blank is False + assert hasattr(field, "remote_field") + assert field.remote_field.on_delete.__name__ == 
"CASCADE" + + def test_object_id_field_properties(self): + field = Context._meta.get_field("entity_id") + assert field.__class__.__name__ == "PositiveIntegerField" + + def test_source_field_properties(self): + field = Context._meta.get_field("source") + assert field.max_length == 100 + assert field.blank is True + assert field.default == "" + + @patch("apps.ai.models.context.Context.save") + @patch("apps.ai.models.context.Context.__init__") + def test_context_creation_with_save(self, mock_init, mock_save): + mock_init.return_value = None + + content = "Test generated text" + source = "test_source" + + context = Context(content=content, source=source) + context.save() + + mock_save.assert_called_once() + + def test_context_inheritance_from_timestamped_model(self): + assert issubclass(Context, TimestampedModel) + + @patch("apps.ai.models.context.Context.objects.create") + def test_context_manager_create(self, mock_create): + mock_context = create_model_mock(Context) + mock_create.return_value = mock_context + + content = "Test text" + source = "test_source" + + result = Context.objects.create(content=content, source=source) + + mock_create.assert_called_once_with(content=content, source=source) + assert result == mock_context + + @patch("apps.ai.models.context.Context.objects.filter") + def test_context_manager_filter(self, mock_filter): + mock_queryset = Mock() + mock_filter.return_value = mock_queryset + + result = Context.objects.filter(source="test_source") + + mock_filter.assert_called_once_with(source="test_source") + assert result == mock_queryset + + @patch("apps.ai.models.context.Context.objects.get") + def test_context_manager_get(self, mock_get): + mock_context = create_model_mock(Context) + mock_get.return_value = mock_context + + result = Context.objects.get(id=1) + + mock_get.assert_called_once_with(id=1) + assert result == mock_context + + @patch("apps.ai.models.context.Context.full_clean") + def test_context_validation(self, mock_full_clean): + 
context = Context() + context.content = "Valid text" + context.source = "A" * 100 + + context.full_clean() + + mock_full_clean.assert_called_once() + + @patch("apps.ai.models.context.Context.full_clean") + def test_context_validation_source_too_long(self, mock_full_clean): + from django.core.exceptions import ValidationError + + mock_full_clean.side_effect = ValidationError("Source too long") + + context = Context() + context.content = "Valid text" + context.source = "A" * 101 + + with pytest.raises(ValidationError) as exc_info: + context.full_clean() + assert "Source too long" in str(exc_info.value) + + def test_context_default_values(self): + context = Context() + + assert context.source == "" + + @patch("apps.ai.models.context.Context.refresh_from_db") + def test_context_refresh_from_db(self, mock_refresh): + context = Context() + context.refresh_from_db() + + mock_refresh.assert_called_once() + + @patch("apps.ai.models.context.Context.delete") + def test_context_delete(self, mock_delete): + context = Context() + context.delete() + + mock_delete.assert_called_once() + + @patch("apps.ai.models.context.Context.objects.get") + def test_update_data_existing_context(self, mock_get): + mock_context = create_model_mock(Context) + mock_get.return_value = mock_context + + content = "Test" + mock_content_object = Mock() + mock_content_object.pk = 1 + + with patch( + "apps.ai.models.context.ContentType.objects.get_for_model" + ) as mock_get_for_model: + mock_content_type = Mock() + mock_content_type.get_source_expressions = Mock(return_value=[]) + mock_get_for_model.return_value = mock_content_type + + result = Context.update_data(content, mock_content_object, source="src", save=True) + + mock_get_for_model.assert_called_once_with(mock_content_object) + mock_get.assert_called_once_with( + entity_type=mock_content_type, + entity_id=1, + ) + assert result == mock_context + assert mock_context.content == content + assert mock_context.source == "src" + + def 
test_str_method_with_name_attribute(self): + """Test __str__ method when entity has name attribute.""" + content_object = Mock() + content_object.name = "Test Object" + + entity_type = Mock(spec=ContentType) + entity_type.model = "test_model" + + context = Context() + context.content = ( + "This is test content that is longer than 50 characters to test truncation" + ) + context.entity_type_id = 1 + context.entity_id = "123" + + with ( + patch.object( + type(context), + "entity", + new_callable=PropertyMock, + return_value=content_object, + ), + patch.object( + type(context), + "entity_type", + new_callable=PropertyMock, + return_value=entity_type, + ), + ): + result = str(context) + assert ( + result + == "test_model Test Object: This is test content that is longer than 50 cha..." + ) + + def test_str_method_with_key_attribute(self): + """Test __str__ method when entity has key but no name attribute.""" + content_object = Mock() + content_object.name = None + content_object.key = "test-key" + + entity_type = Mock(spec=ContentType) + entity_type.model = "test_model" + + context = Context() + context.content = "Short content" + context.entity_type_id = 1 + context.entity_id = "123" + + with ( + patch.object( + type(context), + "entity", + new_callable=PropertyMock, + return_value=content_object, + ), + patch.object( + type(context), + "entity_type", + new_callable=PropertyMock, + return_value=entity_type, + ), + ): + result = str(context) + assert result == "test_model test-key: Short content" + + def test_str_method_with_neither_name_nor_key(self): + """Test __str__ method when entity has neither name nor key attribute.""" + content_object = Mock() + content_object.name = None + content_object.key = None + content_object.__str__ = Mock(return_value="Unknown") + + entity_type = Mock(spec=ContentType) + entity_type.model = "test_model" + + context = Context() + context.content = "Another test content" + context.entity_type_id = 1 + context.entity_id = "456" + + with ( 
+ patch.object( + type(context), + "entity", + new_callable=PropertyMock, + return_value=content_object, + ), + patch.object( + type(context), + "entity_type", + new_callable=PropertyMock, + return_value=entity_type, + ), + ): + result = str(context) + assert result == "test_model Unknown: Another test content" + + def test_str_method_fallback_to_str(self): + """Test __str__ method falls back to str(entity).""" + content_object = Mock() + content_object.name = None + content_object.key = None + content_object.__str__ = Mock(return_value="String representation") + + entity_type = Mock(spec=ContentType) + entity_type.model = "test_model" + + context = Context() + context.content = "Test content" + context.entity_type_id = 1 + context.entity_id = "123" + + with ( + patch.object( + type(context), + "entity", + new_callable=PropertyMock, + return_value=content_object, + ), + patch.object( + type(context), + "entity_type", + new_callable=PropertyMock, + return_value=entity_type, + ), + ): + result = str(context) + assert result == "test_model String representation: Test content" + + @patch("apps.ai.models.context.Context.objects.get") + @patch("apps.ai.models.context.Context.__init__") + def test_update_data_new_context_with_save(self, mock_init, mock_get): + """Test update_data creating a new context with save=True.""" + from apps.ai.models.context import Context as ContextModel + + mock_get.side_effect = ContextModel.DoesNotExist + mock_init.return_value = None # __init__ should return None + + content = "New test content" + mock_content_object = Mock() + mock_content_object.pk = 1 + source = "test_source" + + with patch( + "apps.ai.models.context.ContentType.objects.get_for_model" + ) as mock_get_for_model: + mock_content_type = Mock(spec=ContentType) + mock_content_type.get_source_expressions = Mock(return_value=[]) + mock_get_for_model.return_value = mock_content_type + + # Mock the context instance and its save method + with patch.object(ContextModel, "save") as 
mock_save: + result = Context.update_data( + content, mock_content_object, source=source, save=True + ) + + mock_get_for_model.assert_called_once_with(mock_content_object) + mock_get.assert_called_once_with( + entity_type=mock_content_type, + entity_id=1, + ) + mock_init.assert_called_once_with( + entity_type=mock_content_type, + entity_id=1, + ) + assert result.content == content + assert result.source == source + mock_save.assert_called_once() + + @patch("apps.ai.models.context.Context.objects.get") + @patch("apps.ai.models.context.Context.__init__") + def test_update_data_new_context_without_save(self, mock_init, mock_get): + """Test update_data creating a new context with save=False.""" + from apps.ai.models.context import Context as ContextModel + + mock_get.side_effect = ContextModel.DoesNotExist + mock_init.return_value = None # __init__ should return None + + content = "New test content" + mock_content_object = Mock() + mock_content_object.pk = 1 + source = "test_source" + + with patch( + "apps.ai.models.context.ContentType.objects.get_for_model" + ) as mock_get_for_model: + mock_content_type = Mock(spec=ContentType) + mock_content_type.get_source_expressions = Mock(return_value=[]) + mock_get_for_model.return_value = mock_content_type + + # Mock the context instance and its save method + with patch.object(ContextModel, "save") as mock_save: + result = Context.update_data( + content, mock_content_object, source=source, save=False + ) + + mock_get_for_model.assert_called_once_with(mock_content_object) + mock_get.assert_called_once_with( + entity_type=mock_content_type, + entity_id=1, + ) + mock_init.assert_called_once_with( + entity_type=mock_content_type, + entity_id=1, + ) + assert result.content == content + assert result.source == source + mock_save.assert_not_called() diff --git a/backend/tests/apps/common/open_ai_test.py b/backend/tests/apps/common/open_ai_test.py index ee8827bc19..c4fface6c7 100644 --- a/backend/tests/apps/common/open_ai_test.py +++ 
b/backend/tests/apps/common/open_ai_test.py @@ -60,6 +60,7 @@ def test_complete_general_exception(self, mock_openai, mock_logger, openai_insta response = openai_instance.complete() assert response is None + mock_logger.exception.assert_called_once_with( "An error occurred during OpenAI API request." ) diff --git a/backend/tests/apps/slack/management/__init__.py b/backend/tests/apps/slack/management/__init__.py new file mode 100644 index 0000000000..e69de29bb2