def set_image(self, image_data: bytes | None, mime_type: str | None) -> OpenAi:
    """Attach an image to, or detach it from, subsequent completion requests.

    The image is stored on the instance and stays attached until it is
    replaced or cleared. Pass ``None`` for both arguments to clear it so a
    reused client does not send a stale image with a later text-only request.

    Args:
        image_data (bytes | None): Raw image bytes (any bytes-like object is
            accepted and stored as ``bytes``), or ``None`` to clear.
        mime_type (str | None): MIME type of the image (e.g., "image/png"),
            or ``None`` to clear.

    Returns:
        OpenAi: The current instance.

    """
    # Normalise bytes-like input (bytearray, memoryview) to immutable bytes so
    # later mutation of the caller's buffer cannot change the stored payload.
    self.image_data = None if image_data is None else bytes(image_data)
    self.image_mime_type = mime_type

    return self
@@ -87,11 +106,27 @@ def complete(self) -> str | None: """ try: + # Build user message content + user_content: str | list[dict[str, object]] + if self.image_data and self.image_mime_type: + # Vision API with image + base64_image = base64.b64encode(self.image_data).decode("utf-8") + user_content = [ + {"type": "text", "text": self.input}, + { + "type": "image_url", + "image_url": {"url": f"data:{self.image_mime_type};base64,{base64_image}"}, + }, + ] + else: + # Text-only + user_content = self.input + response = self.client.chat.completions.create( max_tokens=self.max_tokens, messages=[ {"role": "system", "content": self.prompt}, - {"role": "user", "content": self.input}, + {"role": "user", "content": user_content}, ], model=self.model, temperature=self.temperature, diff --git a/backend/apps/slack/events/message_posted.py b/backend/apps/slack/events/message_posted.py index 5b38c3077b..c4821e0da6 100644 --- a/backend/apps/slack/events/message_posted.py +++ b/backend/apps/slack/events/message_posted.py @@ -25,18 +25,27 @@ def __init__(self): def handle_event(self, event, client): """Handle an incoming message event.""" - if event.get("subtype") or event.get("bot_id"): - logger.info("Ignored message due to subtype, bot_id, or thread_ts.") + # Ignore bot messages + if event.get("bot_id"): + logger.info("Ignoring bot message.") return + # Allow file_share subtype (messages with images), ignore others + if event.get("subtype") and event.get("subtype") != "file_share": + logger.info("Ignoring message with subtype: %s", event.get("subtype")) + return + + # Update parent message if this is a thread reply if event.get("thread_ts"): - try: - Message.objects.filter( - slack_message_id=event.get("thread_ts"), - conversation__slack_channel_id=event.get("channel"), - ).update(has_replies=True) - except Message.DoesNotExist: - logger.warning("Thread message not found.") + updated = Message.objects.filter( + slack_message_id=event.get("thread_ts"), + 
conversation__slack_channel_id=event.get("channel"), + ).update(has_replies=True) + if not updated: + logger.info( + "Parent message for thread_ts %s not found in thread reply.", + event.get("thread_ts"), + ) return channel_id = event.get("channel") @@ -49,25 +58,63 @@ def handle_event(self, event, client): is_nest_bot_assistant_enabled=True, ) except Conversation.DoesNotExist: - logger.warning("Conversation not found or assistant not enabled.") + logger.info("Conversation not found or bot not enabled for channel: %s", channel_id) return - if not self.question_detector.is_owasp_question(text): + # Check if message has valid images - only bypass question detector for valid images + from apps.slack.services.image_extraction import ( + extract_images_then_maybe_reply, + is_valid_image_file, + ) + + image_files = [ + f + for f in event.get("files", []) + if f.get("mimetype", "").startswith("image/") and is_valid_image_file(f) + ][:3] + + # For text-only messages or messages without valid images, use question detector + if not image_files and not self.question_detector.is_owasp_question(text): + logger.info("Question detector rejected message") return try: - author = Member.objects.get(slack_user_id=user_id, workspace=conversation.workspace) + author = Member.objects.get( + slack_user_id=user_id, + workspace=conversation.workspace, + ) except Member.DoesNotExist: user_info = client.users_info(user=user_id) - author = Member.update_data(user_info["user"], conversation.workspace, save=True) - logger.info("Created new member") + author = Member.update_data( + user_info["user"], + conversation.workspace, + save=True, + ) + logger.info("Created new member for user_id %s", user_id) message = Message.update_data( - data=event, conversation=conversation, author=author, save=True + data=event, + conversation=conversation, + author=author, + save=True, ) - django_rq.get_queue("ai").enqueue_in( - timedelta(minutes=QUEUE_RESPONSE_TIME_MINUTES), - generate_ai_reply_if_unanswered, - 
@property
def text_with_images(self) -> str:
    """Get message text combined with extracted image text.

    Only extraction entries whose status is "success" and whose extracted text
    is non-empty are included. Each one is rendered as
    ``"\\n[Image: <file name>]\\n<extracted text>"`` and all parts are joined
    with a blank line.

    Returns:
        str: The combined text, or an empty string when there is neither
            message text nor usable extracted text.

    """
    message_text = self.text
    parts = [message_text] if message_text else []

    # ``or`` rather than a ``.get`` default: a key stored with an explicit
    # null would otherwise bypass the default and break iteration.
    raw_data = self.raw_data or {}
    for extraction in raw_data.get("image_extractions") or []:
        if not isinstance(extraction, dict):
            continue
        if extraction.get("status") != "success":
            continue
        extracted_text = extraction.get("extracted_text")
        if not extracted_text:
            continue
        # Same reasoning: a null file name must fall back to "unnamed"
        # instead of being rendered as the string "None".
        file_name = extraction.get("file_name") or "unnamed"
        parts.append(f"\n[Image: {file_name}]\n{extracted_text}")

    return "\n\n".join(parts)
apps.slack.models import Message +from apps.slack.services.message_auto_reply import generate_ai_reply_if_unanswered + +logger = logging.getLogger(__name__) + +MAX_IMAGES_PER_MESSAGE = 3 +SUPPORTED_IMAGE_TYPES = ("image/png", "image/jpeg", "image/jpg", "image/gif", "image/webp") +MAX_IMAGE_SIZE_MB = 20 + +# OpenAI supported formats +OPENAI_SUPPORTED_FORMATS = {"PNG", "JPEG", "GIF", "WEBP"} +# Map PIL formats to MIME types +FORMAT_TO_MIME = { + "PNG": "image/png", + "JPEG": "image/jpeg", + "JPG": "image/jpeg", + "GIF": "image/gif", + "WEBP": "image/webp", +} + + +def is_valid_image_file(file_data: dict) -> bool: + """Check if file is a valid image for extraction. + + Args: + file_data: Dictionary containing file metadata from Slack + + Returns: + True if file is valid for extraction, False otherwise + + """ + mimetype = file_data.get("mimetype", "") + size_bytes = file_data.get("size", 0) + + if mimetype not in SUPPORTED_IMAGE_TYPES: + logger.debug("Skipping non-image file: %s", mimetype) + return False + + size_mb = size_bytes / (1024 * 1024) + if size_mb > MAX_IMAGE_SIZE_MB: + logger.warning( + "Image %s too large: %.2fMB (max: %sMB)", + file_data.get("id"), + size_mb, + MAX_IMAGE_SIZE_MB, + ) + return False + + return True + + +def download_slack_image(url: str, bot_token: str) -> bytes | None: + """Download image from Slack using bot token authentication. + + Args: + url: Slack file URL (url_private or url_private_download) + bot_token: Slack bot token for authentication + + Returns: + Image bytes if successful, None if download fails + + """ + try: + headers = {"Authorization": f"Bearer {bot_token}"} + response = requests.get(url, headers=headers, timeout=30) + response.raise_for_status() + except RequestException: + logger.exception("Failed to download image from %s", url) + return None + else: + return response.content + + +def validate_and_get_image_format(image_data: bytes) -> str | None: + """Validate image and return MIME type if supported by OpenAI. 
+ + Args: + image_data: Raw image bytes + + Returns: + MIME type string if valid, None if unsupported or invalid + + """ + try: + img = Image.open(io.BytesIO(image_data)) + img.verify() # Verify it's actually a valid image + + # Re-open after verify (verify() closes the file) + img = Image.open(io.BytesIO(image_data)) + image_format = img.format + + if not image_format: + logger.warning("Could not detect image format from data") + return None + + # Check if format is supported by OpenAI + if image_format.upper() not in OPENAI_SUPPORTED_FORMATS: + logger.warning( + "Unsupported image format: %s (supported: %s)", + image_format, + ", ".join(OPENAI_SUPPORTED_FORMATS), + ) + return None + + # Return proper MIME type + mime_type = FORMAT_TO_MIME.get(image_format.upper()) + if not mime_type: + logger.warning("No MIME type mapping for format: %s", image_format) + return None + + except (OSError, ValueError) as e: + logger.warning("Failed to validate image: %s", e) + return None + except Exception: + logger.exception("Unexpected error validating image") + return None + else: + logger.debug("Detected valid image format: %s -> %s", image_format, mime_type) + return mime_type + + +def extract_text_from_image(image_data: bytes) -> str: + """Extract text from image using GPT-4o Vision API. + + Args: + image_data: Raw image bytes + + Returns: + Extracted text from the image + + Raises: + ValueError: If image format is invalid or unsupported + + """ + # Validate image format first + mime_type = validate_and_get_image_format(image_data) + if not mime_type: + msg = ( + "Invalid or unsupported image format. " + f"Supported formats: {', '.join(OPENAI_SUPPORTED_FORMATS)}" + ) + raise ValueError(msg) + + # Use the OpenAi class with vision support + open_ai = OpenAi(model="gpt-4o", max_tokens=1000, temperature=0.0) + + prompt = ( + "Extract all text from this image. " + "If there's code, preserve formatting. " + "If it's a screenshot of an error, include the full error message. 
" + "If it's a diagram or chart, describe the key information. " + "Return only the extracted text without any preamble." + ) + + result = ( + open_ai.set_prompt("You are a helpful assistant that extracts text from images.") + .set_input(prompt) + .set_image(image_data, mime_type) + .complete() + ) + + # Handle None response + if result is None: + logger.warning("OpenAI returned None content for image extraction") + return "" + + return result.strip() + + +@job("ai") +def extract_images_then_maybe_reply(message_id: int, image_files: list[dict]) -> None: + """Extract text from images, store results, then queue AI reply. + + Args: + message_id: Primary key of the Slack message + image_files: List of file metadata dictionaries from Slack + + """ + import django_rq + + try: + message = Message.objects.get(pk=message_id) + except Message.DoesNotExist: + logger.error("Message %s not found for image extraction", message_id) # noqa: TRY400 + return + + logger.info("Extracting text from %s images for message %s", len(image_files), message_id) + + bot_token = message.conversation.workspace.bot_token + if not bot_token or bot_token == "None": # noqa: S105 + logger.error("No bot token available") + django_rq.get_queue("ai").enqueue_in( + timedelta(minutes=QUEUE_RESPONSE_TIME_MINUTES), + generate_ai_reply_if_unanswered, + message_id, + ) + return + + extractions = [] + for idx, file_data in enumerate(image_files[:MAX_IMAGES_PER_MESSAGE]): + file_id = file_data.get("id") + file_name = file_data.get("name", f"image_{idx + 1}") + + extraction_result = { + "file_id": file_id, + "file_name": file_name, + "timestamp": datetime.now(UTC).isoformat(), + "model": "gpt-4o", + } + + try: + image_url = file_data.get("url_private_download") or file_data.get("url_private") + if not image_url: + extraction_result.update({"status": "failed", "error": "No image URL found"}) + extractions.append(extraction_result) + continue + + image_data = download_slack_image(image_url, bot_token) + if not 
def _queue_ai_reply(message_id: int) -> None:
    """Schedule the delayed auto-reply job for the given message."""
    import django_rq

    django_rq.get_queue("ai").enqueue_in(
        timedelta(minutes=QUEUE_RESPONSE_TIME_MINUTES),
        generate_ai_reply_if_unanswered,
        message_id,
    )


def _extract_single_image(file_data: dict, index: int, bot_token: str) -> dict:
    """Download one Slack image, extract its text and return the result record."""
    # ``or`` so that a null name in the metadata also gets the positional fallback.
    file_name = file_data.get("name") or f"image_{index + 1}"

    extraction_result = {
        "file_id": file_data.get("id"),
        "file_name": file_name,
        "timestamp": datetime.now(UTC).isoformat(),
        "model": "gpt-4o",
    }

    try:
        image_url = file_data.get("url_private_download") or file_data.get("url_private")
        if not image_url:
            extraction_result.update({"status": "failed", "error": "No image URL found"})
            return extraction_result

        image_data = download_slack_image(image_url, bot_token)
        if not image_data:
            extraction_result.update({"status": "failed", "error": "Failed to download image"})
            return extraction_result

        extracted_text = extract_text_from_image(image_data)

    except ValueError as e:
        # Image format validation error from ai service
        logger.warning("Invalid image format for %s: %s", file_name, e)
        extraction_result.update({"status": "failed", "error": f"Invalid image format: {e}"})

    except Exception:
        # One bad image must not abort the remaining images or the reply.
        logger.exception("Failed to extract text from %s", file_name)
        extraction_result.update({"status": "failed", "error": "Extraction failed"})

    else:
        extraction_result.update(
            {
                "status": "success",
                "extracted_text": extracted_text,
                "size_bytes": len(image_data),
            }
        )
        logger.info(
            "Successfully extracted text from %s (%s chars)", file_name, len(extracted_text)
        )

    return extraction_result


@job("ai")
def extract_images_then_maybe_reply(message_id: int, image_files: list[dict]) -> None:
    """Extract text from images, store results, then queue AI reply.

    At most MAX_IMAGES_PER_MESSAGE files are processed. The per-file results
    are appended to ``raw_data["image_extractions"]`` on the message. The
    delayed AI reply is queued on every path where the message was found,
    including when no bot token is available.

    Args:
        message_id: Primary key of the Slack message
        image_files: List of file metadata dictionaries from Slack

    """
    try:
        message = Message.objects.get(pk=message_id)
    except Message.DoesNotExist:
        logger.error("Message %s not found for image extraction", message_id)  # noqa: TRY400
        return

    logger.info("Extracting text from %s images for message %s", len(image_files), message_id)

    bot_token = message.conversation.workspace.bot_token
    # NOTE(review): the "None" comparison presumably guards against an unset
    # value that was stringified upstream; confirm where that can happen.
    if not bot_token or bot_token == "None":  # noqa: S105
        logger.error("No bot token available")
        _queue_ai_reply(message_id)
        return

    extractions = [
        _extract_single_image(file_data, idx, bot_token)
        for idx, file_data in enumerate(image_files[:MAX_IMAGES_PER_MESSAGE])
    ]

    # Tolerate a null raw_data or a null "image_extractions" entry instead of
    # failing after the extraction work has already been done.
    if message.raw_data is None:
        message.raw_data = {}
    previous = message.raw_data.get("image_extractions") or []
    message.raw_data["image_extractions"] = [*previous, *extractions]
    message.save(update_fields=["raw_data"])

    success_count = sum(1 for e in extractions if e.get("status") == "success")
    logger.info("Completed image extraction: %s/%s successful", success_count, len(extractions))

    _queue_ai_reply(message_id)
else message.text + ai_response_text = process_ai_query(query=query_text, channel_id=channel_id) if not ai_response_text: # Add shrugging reaction when no answer can be generated try: diff --git a/backend/tests/apps/slack/common/handlers/ai_test.py b/backend/tests/apps/slack/common/handlers/ai_test.py index aa218fa290..e675f27382 100644 --- a/backend/tests/apps/slack/common/handlers/ai_test.py +++ b/backend/tests/apps/slack/common/handlers/ai_test.py @@ -29,7 +29,9 @@ def test_get_blocks_with_successful_response(self, mock_markdown, mock_process_a result = get_blocks(query) - mock_process_ai_query.assert_called_once_with(query.strip()) + mock_process_ai_query.assert_called_once_with( + query.strip(), channel_id=None, is_app_mention=False + ) mock_markdown.assert_called_once_with(ai_response) assert result == [expected_block] @@ -45,7 +47,9 @@ def test_get_blocks_with_no_response(self, mock_get_error_blocks, mock_process_a result = get_blocks(query) - mock_process_ai_query.assert_called_once_with(query.strip()) + mock_process_ai_query.assert_called_once_with( + query.strip(), channel_id=None, is_app_mention=False + ) mock_get_error_blocks.assert_called_once() assert result == error_blocks @@ -61,7 +65,9 @@ def test_get_blocks_with_empty_response(self, mock_get_error_blocks, mock_proces result = get_blocks(query) - mock_process_ai_query.assert_called_once_with(query.strip()) + mock_process_ai_query.assert_called_once_with( + query.strip(), channel_id=None, is_app_mention=False + ) mock_get_error_blocks.assert_called_once() assert result == error_blocks @@ -75,7 +81,7 @@ def test_process_ai_query_success(self, mock_process_query): result = process_ai_query(query) - mock_process_query.assert_called_once_with(query) + mock_process_query.assert_called_once_with(query, channel_id=None, is_app_mention=False) assert result == expected_response @patch("apps.slack.common.handlers.ai.process_query") @@ -87,7 +93,7 @@ def test_process_ai_query_failure(self, mock_process_query): 
result = process_ai_query(query) - mock_process_query.assert_called_once_with(query) + mock_process_query.assert_called_once_with(query, channel_id=None, is_app_mention=False) assert result is None @patch("apps.slack.common.handlers.ai.process_query") @@ -99,7 +105,7 @@ def test_process_ai_query_returns_none(self, mock_process_query): result = process_ai_query(query) - mock_process_query.assert_called_once_with(query) + mock_process_query.assert_called_once_with(query, channel_id=None, is_app_mention=False) assert result is None @patch("apps.slack.common.handlers.ai.process_query") @@ -111,7 +117,7 @@ def test_process_ai_query_non_owasp_question(self, mock_process_query): result = process_ai_query(query) - mock_process_query.assert_called_once_with(query) + mock_process_query.assert_called_once_with(query, channel_id=None, is_app_mention=False) assert result == get_default_response() @patch("apps.slack.common.handlers.ai.markdown") @@ -142,4 +148,6 @@ def test_get_blocks_strips_whitespace(self): query_with_whitespace = " What is OWASP? 
" get_blocks(query_with_whitespace) - mock_process_ai_query.assert_called_once_with("What is OWASP?") + mock_process_ai_query.assert_called_once_with( + "What is OWASP?", channel_id=None, is_app_mention=False + ) diff --git a/backend/tests/apps/slack/common/question_detector_test.py b/backend/tests/apps/slack/common/question_detector_test.py index 9ecc1b91aa..3fa017ee6f 100644 --- a/backend/tests/apps/slack/common/question_detector_test.py +++ b/backend/tests/apps/slack/common/question_detector_test.py @@ -24,11 +24,20 @@ def _mock_openai(self, monkeypatch): monkeypatch.setattr("openai.OpenAI", MagicMock(return_value=mock_client)) - # Mock the Retriever class - mock_retriever = MagicMock() - mock_retriever.retrieve.return_value = [] + # Mock embedder and retrieval path used by QuestionDetector + mock_embedder = MagicMock() + mock_embedder.embed_query.return_value = [0.1, 0.2, 0.3] + + def _mock_retrieve_chunks(*_args, **_kwargs): + return [] + + monkeypatch.setattr( + "apps.slack.common.question_detector.get_embedder", + lambda: mock_embedder, + ) monkeypatch.setattr( - "apps.slack.common.question_detector.Retriever", MagicMock(return_value=mock_retriever) + "apps.slack.common.question_detector.QuestionDetector._retrieve_chunks", + _mock_retrieve_chunks, ) monkeypatch.setattr( @@ -70,7 +79,7 @@ def test_init(self, detector): # Test that detector initializes properly assert detector is not None assert hasattr(detector, "openai_client") - assert hasattr(detector, "retriever") + assert hasattr(detector, "embedder") def test_is_owasp_question_true_cases(self, detector, sample_context_chunks, monkeypatch): """Test cases that should be detected as OWASP questions.""" @@ -178,7 +187,7 @@ def test_mocked_initialization(self): # Test that detector initializes properly assert detector is not None assert hasattr(detector, "openai_client") - assert hasattr(detector, "retriever") + assert hasattr(detector, "embedder") def test_class_constants(self, detector): """Test that class 
constants are properly defined.""" diff --git a/backend/tests/apps/slack/events/app_mention_test.py b/backend/tests/apps/slack/events/app_mention_test.py index dbd7c35c8a..e8542701ec 100644 --- a/backend/tests/apps/slack/events/app_mention_test.py +++ b/backend/tests/apps/slack/events/app_mention_test.py @@ -91,7 +91,9 @@ def test_handle_event_success(self, mock_get_blocks, mock_conversation, handler, timestamp="1234567890.123456", name="eyes", ) - mock_get_blocks.assert_called_once_with(query="What is OWASP?") + mock_get_blocks.assert_called_once_with( + query="What is OWASP?", channel_id="C123456", is_app_mention=True + ) mock_client.chat_postMessage.assert_called_once_with( channel="C123456", blocks=[{"type": "section", "text": {"text": "Response"}}], @@ -154,7 +156,9 @@ def test_handle_event_extract_query_from_blocks( handler.handle_event(event, mock_client) - mock_get_blocks.assert_called_once_with(query="What is OWASP?") + mock_get_blocks.assert_called_once_with( + query="What is OWASP?", channel_id="C123456", is_app_mention=True + ) @patch("apps.slack.events.app_mention.Conversation") @patch("apps.slack.events.app_mention.get_blocks") @@ -187,7 +191,9 @@ def test_handle_event_extract_query_from_blocks_multiple_elements( handler.handle_event(event, mock_client) - mock_get_blocks.assert_called_once_with(query="What is OWASP?") + mock_get_blocks.assert_called_once_with( + query="What is OWASP?", channel_id="C123456", is_app_mention=True + ) @patch("apps.slack.events.app_mention.Conversation") @patch("apps.slack.events.app_mention.get_blocks") @@ -212,7 +218,9 @@ def test_handle_event_blocks_not_rich_text( handler.handle_event(event, mock_client) - mock_get_blocks.assert_called_once_with(query="What is OWASP?") + mock_get_blocks.assert_called_once_with( + query="What is OWASP?", channel_id="C123456", is_app_mention=True + ) @patch("apps.slack.events.app_mention.Conversation") @patch("apps.slack.events.app_mention.get_blocks") diff --git 
a/backend/tests/apps/slack/services/message_auto_reply_test.py b/backend/tests/apps/slack/services/message_auto_reply_test.py index 898c52dba2..aa6028ea91 100644 --- a/backend/tests/apps/slack/services/message_auto_reply_test.py +++ b/backend/tests/apps/slack/services/message_auto_reply_test.py @@ -39,6 +39,7 @@ def mock_message(self, mock_conversation): message.id = 1 message.slack_message_id = "1234567890.123456" message.text = "What is OWASP?" + message.text_with_images = None message.conversation = mock_conversation return message @@ -76,8 +77,14 @@ def test_generate_ai_reply_success( ts=mock_message.slack_message_id, limit=1, ) - mock_process_ai_query.assert_called_once_with(query=mock_message.text) - mock_get_blocks.assert_called_once_with("OWASP is a security organization...") + mock_process_ai_query.assert_called_once_with( + query=mock_message.text, + channel_id=mock_message.conversation.slack_channel_id, + ) + mock_get_blocks.assert_called_once_with( + "OWASP is a security organization...", + channel_id=mock_message.conversation.slack_channel_id, + ) mock_client.chat_postMessage.assert_called_once_with( channel=mock_message.conversation.slack_channel_id, blocks=[