Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEAT] Add image mode casting #2562

Merged
merged 4 commits into from
Jul 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions daft/daft.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@ if TYPE_CHECKING:
class ImageMode(Enum):
"""
Supported image modes for Daft's image type.

.. warning::
Currently, only the 8-bit modes (L, LA, RGB, RGBA) can be stored in a DataFrame.
If your binary image data includes other modes, use the `mode` argument
in `image.decode` to convert the images to a supported mode.
"""

#: 8-bit grayscale
Expand Down Expand Up @@ -1133,10 +1138,11 @@ class PyExpr:
def utf8_to_date(self, format: str) -> PyExpr: ...
def utf8_to_datetime(self, format: str, timezone: str | None = None) -> PyExpr: ...
def utf8_normalize(self, remove_punct: bool, lowercase: bool, nfd_unicode: bool, white_space: bool) -> PyExpr: ...
def image_decode(self, raise_error_on_failure: bool) -> PyExpr: ...
def image_decode(self, raise_error_on_failure: bool, mode: ImageMode | None = None) -> PyExpr: ...
def image_encode(self, image_format: ImageFormat) -> PyExpr: ...
def image_resize(self, w: int, h: int) -> PyExpr: ...
def image_crop(self, bbox: PyExpr) -> PyExpr: ...
def image_to_mode(self, mode: ImageMode) -> PyExpr: ...
def list_join(self, delimiter: PyExpr) -> PyExpr: ...
def list_count(self, mode: CountMode) -> PyExpr: ...
def list_get(self, idx: PyExpr, default: PyExpr) -> PyExpr: ...
Expand Down Expand Up @@ -1329,9 +1335,10 @@ class PySeries:
def list_get(self, idx: PySeries, default: PySeries) -> PySeries: ...
def list_slice(self, start: PySeries, end: PySeries) -> PySeries: ...
def map_get(self, key: PySeries) -> PySeries: ...
def image_decode(self, raise_error_on_failure: bool) -> PySeries: ...
def image_decode(self, raise_error_on_failure: bool, mode: ImageMode | None = None) -> PySeries: ...
def image_encode(self, image_format: ImageFormat) -> PySeries: ...
def image_resize(self, w: int, h: int) -> PySeries: ...
def image_to_mode(self, mode: ImageMode) -> PySeries: ...
def if_else(self, other: PySeries, predicate: PySeries) -> PySeries: ...
def is_null(self) -> PySeries: ...
def not_null(self) -> PySeries: ...
Expand Down
24 changes: 21 additions & 3 deletions daft/expressions/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

import daft.daft as native
from daft import context
from daft.daft import CountMode, ImageFormat
from daft.daft import CountMode, ImageFormat, ImageMode
from daft.daft import PyExpr as _PyExpr
from daft.daft import col as _col
from daft.daft import date_lit as _date_lit
Expand Down Expand Up @@ -2881,14 +2881,20 @@ def __repr__(self) -> str:
class ExpressionImageNamespace(ExpressionNamespace):
"""Expression operations for image columns."""

def decode(self, on_error: Literal["raise"] | Literal["null"] = "raise") -> Expression:
def decode(
self,
on_error: Literal["raise"] | Literal["null"] = "raise",
mode: str | ImageMode | None = None,
) -> Expression:
"""
Decodes the binary data in this column into images.
Vince7778 marked this conversation as resolved.
Show resolved Hide resolved

This can only be applied to binary columns that contain encoded images (e.g. PNG, JPEG, etc.)

Args:
on_error: Whether to raise when encountering an error, or log a warning and return a null
mode: What mode to convert the images into before storing them in the column. This may prevent
errors relating to unsupported types.

Returns:
Expression: An Image expression representing an image column.
Expand All @@ -2901,7 +2907,12 @@ def decode(self, on_error: Literal["raise"] | Literal["null"] = "raise") -> Expr
else:
raise NotImplementedError(f"Unimplemented on_error option: {on_error}.")

return Expression._from_pyexpr(self._expr.image_decode(raise_error_on_failure=raise_on_error))
if mode is not None:
if isinstance(mode, str):
mode = ImageMode.from_mode_string(mode.upper())
if not isinstance(mode, ImageMode):
raise ValueError(f"mode must be a string or ImageMode variant, but got: {mode}")
return Expression._from_pyexpr(self._expr.image_decode(raise_error_on_failure=raise_on_error, mode=mode))

def encode(self, image_format: str | ImageFormat) -> Expression:
"""
Expand Down Expand Up @@ -2958,6 +2969,13 @@ def crop(self, bbox: tuple[int, int, int, int] | Expression) -> Expression:
assert isinstance(bbox, Expression)
return Expression._from_pyexpr(self._expr.image_crop(bbox._expr))

def to_mode(self, mode: str | ImageMode) -> Expression:
    """
    Converts the images in this column to the given image mode.

    Args:
        mode: Target mode, given either as an ``ImageMode`` variant or as a string.
            Strings are upper-cased and then resolved with ``ImageMode.from_mode_string``.

    Returns:
        Expression: An expression built from the underlying ``image_to_mode`` call.

    Raises:
        ValueError: If ``mode`` is neither a string nor an ``ImageMode`` variant.
    """
    # Strings are resolved to an ImageMode first; anything else is validated as-is.
    target = ImageMode.from_mode_string(mode.upper()) if isinstance(mode, str) else mode
    if not isinstance(target, ImageMode):
        raise ValueError(f"mode must be a string or ImageMode variant, but got: {target}")
    converted = self._expr.image_to_mode(target)
    return Expression._from_pyexpr(converted)


class ExpressionPartitioningNamespace(ExpressionNamespace):
def days(self) -> Expression:
Expand Down
22 changes: 19 additions & 3 deletions daft/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pyarrow as pa

from daft.arrow_utils import ensure_array, ensure_chunked_array
from daft.daft import CountMode, ImageFormat, PySeries
from daft.daft import CountMode, ImageFormat, ImageMode, PySeries
from daft.datatype import DataType
from daft.utils import pyarrow_supports_fixed_shape_tensor

Expand Down Expand Up @@ -975,15 +975,24 @@ def get(self, key: Series) -> Series:


class SeriesImageNamespace(SeriesNamespace):
def decode(self, on_error: Literal["raise"] | Literal["null"] = "raise") -> Series:
def decode(
self,
on_error: Literal["raise"] | Literal["null"] = "raise",
mode: str | ImageMode | None = None,
) -> Series:
raise_on_error = False
if on_error == "raise":
raise_on_error = True
elif on_error == "null":
raise_on_error = False
else:
raise NotImplementedError(f"Unimplemented on_error option: {on_error}.")
return Series._from_pyseries(self._series.image_decode(raise_error_on_failure=raise_on_error))
if mode is not None:
if isinstance(mode, str):
mode = ImageMode.from_mode_string(mode.upper())
if not isinstance(mode, ImageMode):
raise ValueError(f"mode must be a string or ImageMode variant, but got: {mode}")
return Series._from_pyseries(self._series.image_decode(raise_error_on_failure=raise_on_error, mode=mode))

def encode(self, image_format: str | ImageFormat) -> Series:
if isinstance(image_format, str):
Expand All @@ -999,3 +1008,10 @@ def resize(self, w: int, h: int) -> Series:
raise TypeError(f"expected int for h but got {type(h)}")

return Series._from_pyseries(self._series.image_resize(w, h))

def to_mode(self, mode: str | ImageMode) -> Series:
    """
    Converts the images in this series to the given image mode.

    Args:
        mode: Target mode, given either as an ``ImageMode`` variant or as a string.
            Strings are upper-cased and then resolved with ``ImageMode.from_mode_string``.

    Returns:
        Series: A series built from the underlying ``image_to_mode`` call.

    Raises:
        ValueError: If ``mode`` is neither a string nor an ``ImageMode`` variant.
    """
    # Strings are resolved to an ImageMode first; anything else is validated as-is.
    target = ImageMode.from_mode_string(mode.upper()) if isinstance(mode, str) else mode
    if not isinstance(target, ImageMode):
        raise ValueError(f"mode must be a string or ImageMode variant, but got: {target}")
    converted = self._series.image_to_mode(target)
    return Series._from_pyseries(converted)
1 change: 1 addition & 0 deletions docs/source/api_docs/expressions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,7 @@ Image
Expression.image.encode
Expression.image.resize
Expression.image.crop
Expression.image.to_mode

Partitioning
############
Expand Down
1 change: 1 addition & 0 deletions docs/source/api_docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,4 @@ API Documentation
udf
series
configs
misc
17 changes: 17 additions & 0 deletions docs/source/api_docs/misc.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
Miscellaneous
==================

.. currentmodule:: daft

Types
--------------

Image Types
~~~~~~~~~~~~~~

.. autosummary::
:nosignatures:
:toctree: doc_gen/misc

ImageMode
ImageFormat
82 changes: 79 additions & 3 deletions src/daft-core/src/array/ops/image.rs
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,24 @@ impl<'a> DaftImageBuffer<'a> {
_ => unimplemented!("Mode {self:?} not implemented"),
}
}

pub fn into_mode(self, mode: ImageMode) -> Self {
let img: DynamicImage = self.into();
// I couldn't find a method from the image crate to do this
let img: DynamicImage = match mode {
ImageMode::L => img.into_luma8().into(),
ImageMode::LA => img.into_luma_alpha8().into(),
ImageMode::RGB => img.into_rgb8().into(),
ImageMode::RGBA => img.into_rgba8().into(),
ImageMode::L16 => img.into_luma16().into(),
ImageMode::LA16 => img.into_luma_alpha16().into(),
ImageMode::RGB16 => img.into_rgb16().into(),
ImageMode::RGBA16 => img.into_rgba16().into(),
ImageMode::RGB32F => img.into_rgb32f().into(),
ImageMode::RGBA32F => img.into_rgba32f().into(),
};
img.into()
}
}

fn image_buffer_vec_to_cow<'a, P, T>(input: ImageBuffer<P, Vec<T>>) -> ImageBuffer<P, Cow<'a, [T]>>
Expand All @@ -272,6 +290,19 @@ where
ImageBuffer::from_raw(w, h, owned).unwrap()
}

/// Converts an `ImageBuffer` backed by a `Cow` slice into one backed by an owned `Vec`.
///
/// The pixel data is moved out when the `Cow` already owns it and is cloned only
/// when it is borrowed. The width and height of the input are preserved.
fn image_buffer_cow_to_vec<P, T>(input: ImageBuffer<P, Cow<[T]>>) -> ImageBuffer<P, Vec<T>>
where
    P: image::Pixel<Subpixel = T>,
    Vec<T>: Deref<Target = [P::Subpixel]>,
    // No `[T]: ToOwned` bound on purpose: such a where-clause would take precedence
    // over the slice impl and keep `into_owned` from resolving to `Vec<T>`.
    T: Clone,
{
    let h = input.height();
    let w = input.width();
    // `into_owned` reuses the allocation of a `Cow::Owned`, whereas `to_vec`
    // would clone the data unconditionally.
    let owned: Vec<T> = input.into_raw().into_owned();
    ImageBuffer::from_raw(w, h, owned)
        .expect("pixel data length is unchanged, so it still fits the original dimensions")
}

impl<'a> From<DynamicImage> for DaftImageBuffer<'a> {
fn from(dyn_img: DynamicImage) -> Self {
match dyn_img {
Expand Down Expand Up @@ -310,6 +341,23 @@ impl<'a> From<DynamicImage> for DaftImageBuffer<'a> {
}
}

/// Converts a `DaftImageBuffer` into a `DynamicImage`.
///
/// Every variant has its `Cow`-backed buffer turned into a `Vec`-backed one via
/// `image_buffer_cow_to_vec`, which is then wrapped in a `DynamicImage`.
impl<'a> From<DaftImageBuffer<'a>> for DynamicImage {
    fn from(buffer: DaftImageBuffer<'a>) -> Self {
        match buffer {
            DaftImageBuffer::L(pixels) => Self::from(image_buffer_cow_to_vec(pixels)),
            DaftImageBuffer::LA(pixels) => Self::from(image_buffer_cow_to_vec(pixels)),
            DaftImageBuffer::RGB(pixels) => Self::from(image_buffer_cow_to_vec(pixels)),
            DaftImageBuffer::RGBA(pixels) => Self::from(image_buffer_cow_to_vec(pixels)),
            DaftImageBuffer::L16(pixels) => Self::from(image_buffer_cow_to_vec(pixels)),
            DaftImageBuffer::LA16(pixels) => Self::from(image_buffer_cow_to_vec(pixels)),
            DaftImageBuffer::RGB16(pixels) => Self::from(image_buffer_cow_to_vec(pixels)),
            DaftImageBuffer::RGBA16(pixels) => Self::from(image_buffer_cow_to_vec(pixels)),
            DaftImageBuffer::RGB32F(pixels) => Self::from(image_buffer_cow_to_vec(pixels)),
            DaftImageBuffer::RGBA32F(pixels) => Self::from(image_buffer_cow_to_vec(pixels)),
        }
    }
}

pub struct ImageArraySidecarData {
pub channels: Vec<u16>,
pub heights: Vec<u32>,
Expand Down Expand Up @@ -569,6 +617,14 @@ impl ImageArray {
},
)
}

/// Returns a new `ImageArray` in which every non-null image has been converted to `mode`.
///
/// Null entries stay null. The result keeps this array's name and is built
/// with `mode` as its image mode.
pub fn to_mode(&self, mode: ImageMode) -> DaftResult<Self> {
    let converted = self
        .into_iter()
        .map(|maybe_img| maybe_img.map(|buf| buf.into_mode(mode)))
        .collect::<Vec<Option<DaftImageBuffer>>>();
    let target_mode = Some(mode);
    Self::from_daft_image_buffers(self.name(), &converted, &target_mode)
}
}

impl AsImageObj for ImageArray {
Expand Down Expand Up @@ -723,6 +779,19 @@ impl FixedShapeImageArray {
let result = crop_images(self, &mut bboxes_iterator);
ImageArray::from_daft_image_buffers(self.name(), result.as_slice(), &Some(self.mode()))
}

/// Returns a new `FixedShapeImageArray` in which every non-null image has been converted to `mode`.
///
/// Null entries stay null. The height and width are taken from this array's
/// `DataType::FixedShapeImage` data type and reused for the result.
pub fn to_mode(&self, mode: ImageMode) -> DaftResult<Self> {
    // The fixed shape is carried by the data type, so read it back from there.
    let (height, width) = match self.data_type() {
        DataType::FixedShapeImage(_, h, w) => (*h, *w),
        _ => unreachable!("self should always be a FixedShapeImage"),
    };

    let converted = self
        .into_iter()
        .map(|maybe_img| maybe_img.map(|buf| buf.into_mode(mode)))
        .collect::<Vec<Option<DaftImageBuffer>>>();

    Self::from_daft_image_buffers(self.name(), &converted, &mode, height, width)
}
}

impl AsImageObj for FixedShapeImageArray {
Expand Down Expand Up @@ -787,7 +856,11 @@ where
}

impl BinaryArray {
pub fn image_decode(&self, raise_error_on_failure: bool) -> DaftResult<ImageArray> {
pub fn image_decode(
&self,
raise_error_on_failure: bool,
mode: Option<ImageMode>,
) -> DaftResult<ImageArray> {
let arrow_array = self
.data()
.as_any()
Expand All @@ -798,7 +871,7 @@ impl BinaryArray {
// Load images from binary buffers.
// Confirm that all images have the same value dtype.
for (index, row) in arrow_array.iter().enumerate() {
let img_buf = match row.map(DaftImageBuffer::decode).transpose() {
let mut img_buf = match row.map(DaftImageBuffer::decode).transpose() {
Ok(val) => val,
Err(err) => {
if raise_error_on_failure {
Expand All @@ -812,6 +885,9 @@ impl BinaryArray {
}
}
};
if let Some(mode) = mode {
img_buf = img_buf.map(|buf| buf.into_mode(mode));
}
let dtype = img_buf.as_ref().map(|im| im.mode().get_dtype());
match (dtype.as_ref(), cached_dtype.as_ref()) {
(Some(t1), Some(t2)) => {
Expand All @@ -829,7 +905,7 @@ impl BinaryArray {
// Fall back to UInt8 dtype if series is all nulls.
let cached_dtype = cached_dtype.unwrap_or(DataType::UInt8);
match cached_dtype {
DataType::UInt8 => Ok(ImageArray::from_daft_image_buffers(self.name(), img_bufs.as_slice(), &None)?),
DataType::UInt8 => Ok(ImageArray::from_daft_image_buffers(self.name(), img_bufs.as_slice(), &mode)?),
_ => unimplemented!("Decoding images of dtype {cached_dtype:?} is not supported, only uint8 images are supported."),
}
}
Expand Down
5 changes: 5 additions & 0 deletions src/daft-core/src/datatypes/image_mode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@ use common_error::{DaftError, DaftResult};

/// Supported image modes for Daft's image type.
///
/// .. warning::
/// Currently, only the 8-bit modes (L, LA, RGB, RGBA) can be stored in a DataFrame.
/// If your binary image data includes other modes, use the `mode` argument
/// in `image.decode` to convert the images to a supported mode.
///
/// | L - 8-bit grayscale
/// | LA - 8-bit grayscale + alpha
/// | RGB - 8-bit RGB
Expand Down
15 changes: 13 additions & 2 deletions src/daft-core/src/python/series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -649,8 +649,15 @@ impl PySeries {
Ok(self.series.map_get(&key.series)?.into())
}

pub fn image_decode(&self, raise_error_on_failure: bool) -> PyResult<Self> {
Ok(self.series.image_decode(raise_error_on_failure)?.into())
pub fn image_decode(
&self,
raise_error_on_failure: bool,
mode: Option<ImageMode>,
) -> PyResult<Self> {
Ok(self
.series
.image_decode(raise_error_on_failure, mode)?
.into())
}

pub fn image_encode(&self, image_format: ImageFormat) -> PyResult<Self> {
Expand All @@ -672,6 +679,10 @@ impl PySeries {
Ok(self.series.image_resize(w as u32, h as u32)?.into())
}

/// Converts the images in this series to `mode` and wraps the result for Python.
///
/// Any error from the underlying `image_to_mode` call is propagated to the caller.
pub fn image_to_mode(&self, mode: &ImageMode) -> PyResult<Self> {
    let converted = self.series.image_to_mode(*mode)?;
    Ok(converted.into())
}

pub fn if_else(&self, other: &Self, predicate: &Self) -> PyResult<Self> {
Ok(self
.series
Expand Down
Loading
Loading