diff --git a/tests/multimodal/media/test_image.py b/tests/multimodal/media/test_image.py index 065a40d68e35..65196d7805cc 100644 --- a/tests/multimodal/media/test_image.py +++ b/tests/multimodal/media/test_image.py @@ -131,3 +131,77 @@ def test_image_media_io_rgba_background_color_validation(): ImageMediaIO(rgba_background_color=(0, 0, 0)) # Should not raise ImageMediaIO(rgba_background_color=[255, 255, 255]) # Should not raise ImageMediaIO(rgba_background_color=(128, 128, 128)) # Should not raise + + +def test_image_media_io_load_bytes(tmp_path): + """Test load_bytes with valid and invalid image data.""" + # Save a valid RGB image to use as source bytes + valid_image = Image.new("RGB", (8, 8), (100, 150, 200)) + valid_path = tmp_path / "valid.png" + valid_image.save(valid_path) + + valid_data = valid_path.read_bytes() + + # Test 1: Valid image bytes load successfully and are fully decoded + image_io = ImageMediaIO() + result = image_io.load_bytes(valid_data) + + # Check the returned media is a properly loaded image + assert isinstance(result.media, Image.Image) + assert result.media.size == (8, 8) + assert result.media.getpixel((0, 0)) == (100, 150, 200) + + # Test 2: Garbage bytes raise ValueError + with pytest.raises(ValueError, match="Failed to load image"): + image_io.load_bytes(b"not an image") + + # Test 3: Truncated PNG header raises ValueError + with pytest.raises(ValueError, match="Failed to load image"): + image_io.load_bytes(b"\x89PNG\r\n\x1a\n" + b"\x00" * 10) + + # Test 4: Real PNG truncated mid-stream raises ValueError + with pytest.raises(ValueError, match="Failed to load image"): + image_io.load_bytes(valid_data[: len(valid_data) // 2]) + + # Test 5: Empty bytes raise ValueError + with pytest.raises(ValueError, match="Failed to load image"): + image_io.load_bytes(b"") + + +def test_image_media_io_load_file(tmp_path): + """Test load_file with valid and invalid image files.""" + # Save a valid RGB image to disk + valid_image = Image.new("RGB", (4, 4), (10, 20, 30)) + valid_path = tmp_path / "valid.png" + valid_image.save(valid_path) + + # Test 1: Valid image file loads successfully and is fully decoded + image_io = ImageMediaIO() + result = image_io.load_file(valid_path) + + # Check the returned media is a properly loaded image + assert isinstance(result.media, Image.Image) + assert result.media.size == (4, 4) + assert result.media.getpixel((0, 0)) == (10, 20, 30) + + # Test 2: File with garbage content raises ValueError + bad_file = tmp_path / "bad.png" + bad_file.write_bytes(b"this is not an image") + + with pytest.raises(ValueError, match="Failed to load image"): + image_io.load_file(bad_file) + + # Test 3: File with truncated PNG header raises ValueError + truncated_file = tmp_path / "truncated.png" + truncated_file.write_bytes(b"\x89PNG\r\n\x1a\n" + b"\x00" * 10) + + with pytest.raises(ValueError, match="Failed to load image"): + image_io.load_file(truncated_file) + + # Test 4: Real PNG file truncated mid-stream raises ValueError + valid_data = valid_path.read_bytes() + truncated_real_file = tmp_path / "truncated_real.png" + truncated_real_file.write_bytes(valid_data[: len(valid_data) // 2]) + + with pytest.raises(ValueError, match="Failed to load image"): + image_io.load_file(truncated_real_file) diff --git a/vllm/multimodal/media/image.py b/vllm/multimodal/media/image.py index ea4bf7b01527..ea816b760fea 100644 --- a/vllm/multimodal/media/image.py +++ b/vllm/multimodal/media/image.py @@ -68,17 +68,19 @@ def _convert_image_mode( return convert_image_mode(image, self.image_mode) def load_bytes(self, data: bytes) -> MediaWithBytes[Image.Image]: - image = Image.open(BytesIO(data)) - return MediaWithBytes(self._convert_image_mode(image), data) + try: + image = Image.open(BytesIO(data)) + image.load() + image = self._convert_image_mode(image) + except (OSError, Image.UnidentifiedImageError) as e: + raise ValueError(f"Failed to load image: {e}") from e + return MediaWithBytes(image, data) def load_base64(self, media_type: str, data: str) -> MediaWithBytes[Image.Image]: return self.load_bytes(pybase64.b64decode(data, validate=True)) def load_file(self, filepath: Path) -> MediaWithBytes[Image.Image]: - with open(filepath, "rb") as f: - data = f.read() - image = Image.open(BytesIO(data)) - return MediaWithBytes(self._convert_image_mode(image), data) + return self.load_bytes(filepath.read_bytes()) def encode_base64( self,