From d235a0de7493bc5026394c185de53c046dc26765 Mon Sep 17 00:00:00 2001 From: Beibin Li Date: Tue, 2 Apr 2024 09:36:13 -0700 Subject: [PATCH] Update mm test: create dummy image in case file corrupt --- .../capabilities/test_vision_capability.py | 20 +++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/test/agentchat/contrib/capabilities/test_vision_capability.py b/test/agentchat/contrib/capabilities/test_vision_capability.py index a62d5245057f..eaa1ec9bc251 100644 --- a/test/agentchat/contrib/capabilities/test_vision_capability.py +++ b/test/agentchat/contrib/capabilities/test_vision_capability.py @@ -1,3 +1,4 @@ +import os from unittest.mock import MagicMock, patch import pytest @@ -5,6 +6,8 @@ from autogen.agentchat.conversable_agent import ConversableAgent try: + from PIL import Image + from autogen.agentchat.contrib.capabilities.vision_capability import VisionCapability except ImportError: skip_test = True @@ -21,6 +24,15 @@ def lmm_config(): } +def png_filename() -> str: + filename = "tmp/test_image.png" + if not os.path.exists(filename): + # Setup: Create a PNG file + image = Image.new("RGB", (100, 100), color="blue") + image.save(filename) + return filename # This is what the test will use + + @pytest.fixture def vision_capability(lmm_config): return VisionCapability(lmm_config, custom_caption_func=None) @@ -72,9 +84,9 @@ def test_process_last_received_message_text(mock_lmm_client, vision_capability): def test_process_last_received_message_with_image( mock_get_caption, mock_convert_base64, mock_get_image_data, vision_capability ): - content = [{"type": "image_url", "image_url": {"url": "notebook/viz_gc.png"}}] + content = [{"type": "image_url", "image_url": {"url": (png_filename())}}] expected_caption = ( - " in case you can not see, the caption of this image is: A sample image caption.\n" + f" in case you can not see, the caption of this image is: A sample image caption.\n" ) processed_content = 
vision_capability.process_last_received_message(content) assert processed_content == expected_caption @@ -101,7 +113,7 @@ def caption_func(image_url: str, image_data=None, lmm_client=None) -> str: class TestCustomCaptionFunc: def test_custom_caption_func_with_valid_url(self, custom_caption_func): """Test custom caption function with a valid image URL.""" - image_url = "notebook/viz_gc.png" + image_url = png_filename() expected_caption = f"An image description. The image is from {image_url}." assert custom_caption_func(image_url) == expected_caption, "Caption does not match expected output." @@ -109,7 +121,7 @@ def test_process_last_received_message_with_custom_func(self, lmm_config, custom """Test processing a message containing an image URL with a custom caption function.""" vision_capability = VisionCapability(lmm_config, custom_caption_func=custom_caption_func) - image_url = "notebook/viz_gc.png" + image_url = png_filename() content = [{"type": "image_url", "image_url": {"url": image_url}}] expected_output = f" An image description. The image is from {image_url}." processed_content = vision_capability.process_last_received_message(content)