Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 74 additions & 0 deletions tests/multimodal/media/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,77 @@ def test_image_media_io_rgba_background_color_validation():
ImageMediaIO(rgba_background_color=(0, 0, 0)) # Should not raise
ImageMediaIO(rgba_background_color=[255, 255, 255]) # Should not raise
ImageMediaIO(rgba_background_color=(128, 128, 128)) # Should not raise


def test_image_media_io_load_bytes(tmp_path):
"""Test load_bytes with valid and invalid image data."""
# Save a valid RGB image to use as source bytes
valid_image = Image.new("RGB", (8, 8), (100, 150, 200))
valid_path = tmp_path / "valid.png"
valid_image.save(valid_path)

valid_data = valid_path.read_bytes()

# Test 1: Valid image bytes load successfully and are fully decoded
image_io = ImageMediaIO()
result = image_io.load_bytes(valid_data)

# Check the returned media is a properly loaded image
assert isinstance(result.media, Image.Image)
assert result.media.size == (8, 8)
assert result.media.getpixel((0, 0)) == (100, 150, 200)

# Test 2: Garbage bytes raise ValueError
with pytest.raises(ValueError, match="Failed to load image"):
image_io.load_bytes(b"not an image")

# Test 3: Truncated PNG header raises ValueError
with pytest.raises(ValueError, match="Failed to load image"):
image_io.load_bytes(b"\x89PNG\r\n\x1a\n" + b"\x00" * 10)

# Test 4: Real PNG truncated mid-stream raises ValueError
with pytest.raises(ValueError, match="Failed to load image"):
image_io.load_bytes(valid_data[: len(valid_data) // 2])

# Test 5: Empty bytes raise ValueError
with pytest.raises(ValueError, match="Failed to load image"):
image_io.load_bytes(b"")


def test_image_media_io_load_file(tmp_path):
"""Test load_file with valid and invalid image files."""
# Save a valid RGB image to disk
valid_image = Image.new("RGB", (4, 4), (10, 20, 30))
valid_path = tmp_path / "valid.png"
valid_image.save(valid_path)

# Test 1: Valid image file loads successfully and is fully decoded
image_io = ImageMediaIO()
result = image_io.load_file(valid_path)

# Check the returned media is a properly loaded image
assert isinstance(result.media, Image.Image)
assert result.media.size == (4, 4)
assert result.media.getpixel((0, 0)) == (10, 20, 30)

# Test 2: File with garbage content raises ValueError
bad_file = tmp_path / "bad.png"
bad_file.write_bytes(b"this is not an image")

with pytest.raises(ValueError, match="Failed to load image"):
image_io.load_file(bad_file)

# Test 3: File with truncated PNG header raises ValueError
truncated_file = tmp_path / "truncated.png"
truncated_file.write_bytes(b"\x89PNG\r\n\x1a\n" + b"\x00" * 10)

with pytest.raises(ValueError, match="Failed to load image"):
image_io.load_file(truncated_file)

# Test 4: Real PNG file truncated mid-stream raises ValueError
valid_data = valid_path.read_bytes()
truncated_real_file = tmp_path / "truncated_real.png"
truncated_real_file.write_bytes(valid_data[: len(valid_data) // 2])

with pytest.raises(ValueError, match="Failed to load image"):
image_io.load_file(truncated_real_file)
14 changes: 8 additions & 6 deletions vllm/multimodal/media/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,17 +68,19 @@ def _convert_image_mode(
return convert_image_mode(image, self.image_mode)

def load_bytes(self, data: bytes) -> MediaWithBytes[Image.Image]:
image = Image.open(BytesIO(data))
return MediaWithBytes(self._convert_image_mode(image), data)
try:
image = Image.open(BytesIO(data))
image.load()
image = self._convert_image_mode(image)
except (OSError, Image.UnidentifiedImageError) as e:
raise ValueError(f"Failed to load image: {e}") from e
return MediaWithBytes(image, data)

def load_base64(self, media_type: str, data: str) -> MediaWithBytes[Image.Image]:
return self.load_bytes(pybase64.b64decode(data, validate=True))

def load_file(self, filepath: Path) -> MediaWithBytes[Image.Image]:
with open(filepath, "rb") as f:
data = f.read()
image = Image.open(BytesIO(data))
return MediaWithBytes(self._convert_image_mode(image), data)
return self.load_bytes(filepath.read_bytes())

def encode_base64(
self,
Expand Down
Loading