Skip to content

Commit 6df8f84

Browse files
committed
feat: expose image attribute as expression
1 parent 1409f8f commit 6df8f84

File tree

7 files changed

+192
-7
lines changed

7 files changed

+192
-7
lines changed

daft/expressions/expressions.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5153,6 +5153,18 @@ def to_mode(self, mode: str | ImageMode) -> Expression:
51535153
image_mode = lit(mode)._expr
51545154
f = native.get_function_from_registry("to_mode")
51555155
return Expression._from_pyexpr(f(self._expr, mode=image_mode))
5156+
5157+
def attribute(self, name: str) -> Expression:
    """Extracts a single metadata attribute from every image in the column.

    Args:
        name (str): Attribute to read. The native implementation matches the
            name case-insensitively and accepts ``"height"``, ``"width"``,
            ``"channels"`` and ``"mode"`` (as well as the variants
            ``"heights"``, ``"widths"``, ``"channel"`` and ``"modes"``).

    Returns:
        Expression: A UInt32 Expression holding the requested attribute.
    """
    attribute_fn = native.get_function_from_registry("image_attribute")
    attribute_expr = attribute_fn(self._expr, lit(name)._expr)
    return Expression._from_pyexpr(attribute_expr)
51565168

51575169

51585170
class ExpressionPartitioningNamespace(ExpressionNamespace):

src/daft-core/src/array/image_array.rs

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -56,23 +56,39 @@ impl ImageArray {
5656
}
5757

5858
/// Returns the channel child of this image array as the underlying arrow2 `UInt16Array`.
///
/// Thin arrow-level view over [`Self::channels`].
pub fn channel_array(&self) -> &arrow2::array::UInt16Array {
    let channels = self.channels();
    channels.as_arrow()
}
6261

6362
/// Returns the height child of this image array as the underlying arrow2 `UInt32Array`.
///
/// Thin arrow-level view over [`Self::heights`].
pub fn height_array(&self) -> &arrow2::array::UInt32Array {
    let heights = self.heights();
    heights.as_arrow()
}
6765

6866
/// Returns the width child of this image array as the underlying arrow2 `UInt32Array`.
///
/// Thin arrow-level view over [`Self::widths`].
pub fn width_array(&self) -> &arrow2::array::UInt32Array {
    let widths = self.widths();
    widths.as_arrow()
}
7269

7370
/// Returns the mode child of this image array as the underlying arrow2 `UInt8Array`.
///
/// Thin arrow-level view over [`Self::modes`].
pub fn mode_array(&self) -> &arrow2::array::UInt8Array {
    let modes = self.modes();
    modes.as_arrow()
}
73+
74+
/// Returns the physical child at `IMAGE_CHANNEL_IDX` as a typed `UInt16` array.
///
/// # Panics
/// Panics if the physical struct has no child at `IMAGE_CHANNEL_IDX` or if that
/// child is not a `UInt16` array; either case means the image array is malformed.
pub fn channels(&self) -> &DataArray<UInt16Type> {
    self.physical
        .children
        .get(Self::IMAGE_CHANNEL_IDX)
        .expect("image struct array is missing its channel child")
        .u16()
        .expect("image channel child must be a UInt16 array")
}
78+
79+
/// Returns the physical child at `IMAGE_HEIGHT_IDX` as a typed `UInt32` array.
///
/// # Panics
/// Panics if the physical struct has no child at `IMAGE_HEIGHT_IDX` or if that
/// child is not a `UInt32` array; either case means the image array is malformed.
pub fn heights(&self) -> &DataArray<UInt32Type> {
    self.physical
        .children
        .get(Self::IMAGE_HEIGHT_IDX)
        .expect("image struct array is missing its height child")
        .u32()
        .expect("image height child must be a UInt32 array")
}
83+
84+
/// Returns the physical child at `IMAGE_WIDTH_IDX` as a typed `UInt32` array.
///
/// # Panics
/// Panics if the physical struct has no child at `IMAGE_WIDTH_IDX` or if that
/// child is not a `UInt32` array; either case means the image array is malformed.
pub fn widths(&self) -> &DataArray<UInt32Type> {
    self.physical
        .children
        .get(Self::IMAGE_WIDTH_IDX)
        .expect("image struct array is missing its width child")
        .u32()
        .expect("image width child must be a UInt32 array")
}
88+
89+
/// Returns the physical child at `IMAGE_MODE_IDX` as a typed `UInt8` array.
///
/// # Panics
/// Panics if the physical struct has no child at `IMAGE_MODE_IDX` or if that
/// child is not a `UInt8` array; either case means the image array is malformed.
pub fn modes(&self) -> &DataArray<UInt8Type> {
    self.physical
        .children
        .get(Self::IMAGE_MODE_IDX)
        .expect("image struct array is missing its mode child")
        .u8()
        .expect("image mode child must be a UInt8 array")
}
7793

7894
pub fn from_list_array(
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
use common_error::{DaftError, DaftResult};
2+
use daft_core::prelude::*;
3+
use daft_dsl::{
4+
functions::{FunctionArgs, ScalarUDF},
5+
ExprRef,
6+
};
7+
use serde::{Deserialize, Serialize};
8+
9+
/// Scalar function, named `image_attribute`, that extracts a metadata
/// attribute (height/width/channels/mode) from an image column.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
10+
pub struct ImageAttribute;
11+
12+
/// Arguments of [`ImageAttribute`], generic over the input representation `T`.
#[derive(FunctionArgs)]
13+
struct ImageAttributeArgs<T> {
14+
// The image input: a `Series` when the function is evaluated, an `ExprRef`
// when its return field is resolved.
input: T,
15+
// Name of the attribute to extract.
attr: String,
16+
}
17+
18+
#[typetag::serde]
19+
impl ScalarUDF for ImageAttribute {
20+
fn call(&self, inputs: FunctionArgs<Series>) -> DaftResult<Series> {
21+
let ImageAttributeArgs { input, attr } = inputs.try_into()?;
22+
crate::series::attribute(&input, &attr)
23+
}
24+
25+
fn name(&self) -> &'static str {
26+
"image_attribute"
27+
}
28+
29+
fn get_return_field(
30+
&self,
31+
inputs: FunctionArgs<ExprRef>,
32+
schema: &Schema,
33+
) -> DaftResult<Field> {
34+
let ImageAttributeArgs { input, attr } = inputs.try_into()?;
35+
36+
let input_field = input.to_field(schema)?;
37+
match input_field.dtype {
38+
DataType::Image(_) | DataType::FixedShapeImage(..) => {
39+
Ok(Field::new(input_field.name, DataType::UInt32))
40+
}
41+
_ => Err(DaftError::TypeError(format!(
42+
"Image attribute can only be retrieved from ImageArrays, got {}",
43+
input_field.dtype
44+
))),
45+
}
46+
}
47+
48+
fn docstring(&self) -> &'static str {
49+
"Extracts metadata attributes from image series (height/width/channels/mode)"
50+
}
51+
}

src/daft-image/src/functions/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ pub mod decode;
55
pub mod encode;
66
pub mod resize;
77
pub mod to_mode;
8+
pub mod attribute;
89

910
pub struct ImageFunctions;
1011

@@ -15,5 +16,6 @@ impl FunctionModule for ImageFunctions {
1516
parent.add_fn(encode::ImageEncode);
1617
parent.add_fn(resize::ImageResize);
1718
parent.add_fn(to_mode::ImageToMode);
19+
parent.add_fn(attribute::ImageAttribute);
1820
}
1921
}

src/daft-image/src/ops.rs

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ pub trait ImageOps {
3838
fn to_mode(&self, mode: ImageMode) -> DaftResult<Self>
3939
where
4040
Self: Sized;
41+
/// Returns the image attribute named by `attr` for every element as a
/// `UInt32` array.
fn attribute(&self, attr: &str) -> DaftResult<DataArray<UInt32Type>>;
4142
}
4243

4344
pub(crate) fn image_array_from_img_buffers(
@@ -194,8 +195,22 @@ impl ImageOps for ImageArray {
194195
.collect();
195196
image_array_from_img_buffers(self.name(), &buffers, Some(mode))
196197
}
198+
199+
fn attribute(&self, attr: &str) -> DaftResult<DataArray<UInt32Type>> {
200+
match attr.to_lowercase().as_str() {
201+
"height" | "heights" => Ok(self.heights().clone().rename(self.name())),
202+
"width" | "widths" => Ok(self.widths().clone().rename(self.name())),
203+
"channel" | "channels" => Ok(self.channels().clone().cast(&DataType::UInt32)?.u32()?.clone().rename(self.name())),
204+
"mode" | "modes" => Ok(self.modes().clone().cast(&DataType::UInt32)?.u32()?.clone().rename(self.name())),
205+
_ => Err(DaftError::ValueError(format!(
206+
"Unsupported image attribute: {}, available: [heights, widths, channels, modes]",
207+
attr
208+
))),
209+
}
210+
}
197211
}
198212

213+
199214
impl ImageOps for FixedShapeImageArray {
200215
fn encode(&self, image_format: ImageFormat) -> DaftResult<BinaryArray> {
201216
encode_images(self, image_format)
@@ -254,6 +269,36 @@ impl ImageOps for FixedShapeImageArray {
254269
};
255270
fixed_image_array_from_img_buffers(self.name(), &buffers, &mode, *height, *width)
256271
}
272+
273+
fn attribute(&self, attr: &str) -> DaftResult<DataArray<UInt32Type>> {
274+
let (height, width) = match self.data_type() {
275+
DataType::FixedShapeImage(_, h, w) => (h, w),
276+
_ => unreachable!("Should be FixedShapeImage type"),
277+
};
278+
279+
match attr.to_lowercase().as_str() {
280+
"height" | "heights" => Ok(UInt32Array::from((
281+
self.name(),
282+
vec![*height; self.len()].as_slice(),
283+
))),
284+
"width" | "widths" => Ok(UInt32Array::from((
285+
self.name(),
286+
vec![*width; self.len()].as_slice(),
287+
))),
288+
"channel" | "channels" => Ok(UInt32Array::from((
289+
self.name(),
290+
vec![self.image_mode().num_channels() as u32; self.len()].as_slice(),
291+
))),
292+
"mode" | "modes" => Ok(UInt32Array::from((
293+
self.name(),
294+
vec![(*self.image_mode() as u8) as u32; self.len()].as_slice(),
295+
))),
296+
_ => Err(DaftError::ValueError(format!(
297+
"Unsupported image attribute: {}, available: [heights, widths, channels, modes]",
298+
attr
299+
))),
300+
}
301+
}
257302
}
258303

259304
impl AsImageObj for ImageArray {

src/daft-image/src/series.rs

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,3 +197,30 @@ pub fn to_mode(s: &Series, mode: ImageMode) -> DaftResult<Series> {
197197
))),
198198
}
199199
}
200+
201+
202+
/// Get metadata attributes from image series
203+
///
204+
/// # Arguments
205+
/// * `s` - Input Series containing image data
206+
/// * `attr` - Attribute name to retrieve ("height", "width", "channels", "mode")
207+
///
208+
/// # Returns
209+
/// Series of UInt32 values containing requested attribute
210+
pub fn attribute(s: &Series, attr: &str) -> DaftResult<Series> {
211+
match s.data_type() {
212+
DataType::Image(_) => {
213+
let array = s.downcast::<ImageArray>()?;
214+
Ok(array.attribute(attr)?.into_series())
215+
}
216+
DataType::FixedShapeImage(..) => {
217+
let array = s.downcast::<FixedShapeImageArray>()?;
218+
Ok(array.attribute(attr)?.into_series())
219+
}
220+
dt => Err(DaftError::ValueError(format!(
221+
"datatype: {} does not support Image attributes. Occurred while processing Series: {}",
222+
dt,
223+
s.name()
224+
))),
225+
}
226+
}
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
from __future__ import annotations
2+
3+
import pytest
4+
5+
import daft
6+
from daft import DataType
7+
from daft.recordbatch import MicroPartition
8+
from tests.recordbatch.image.conftest import MODE_TO_NUM_CHANNELS
9+
10+
11+
@pytest.mark.parametrize("attr", ["width", "height", "channels", "mode"])
def test_image_attribute_mixed_shape(mixed_shape_data_fixture, attr):
    """Checks `image.attribute(attr)` on a mixed-shape image column, ignoring null rows."""
    table = daft.from_pydict({"images": mixed_shape_data_fixture})
    table = table.with_column(attr, daft.col("images").image.attribute(attr))
    values = table.to_pydict()[attr]

    image_dtype = mixed_shape_data_fixture.datatype()
    mode = image_dtype.image_mode
    mode_name = str(mode).split(".")[-1]

    # Null images produce null attributes; only the populated rows are checked.
    present = [x for x in values if x is not None]

    if attr == "width":
        assert all(x in (3, 4) for x in present)
    elif attr == "height":
        assert all(x in (2, 3) for x in present)
    elif attr == "channels":
        expected_channels = MODE_TO_NUM_CHANNELS[mode_name]
        assert all(x == expected_channels for x in present)
    elif attr == "mode":
        # The "mode" attribute is the image mode's integer value, not its channel
        # count; the two only coincide for some modes.
        # NOTE(review): assumes ImageMode supports int() conversion — confirm.
        expected_mode = int(mode)
        assert all(x == expected_mode for x in present)
31+
32+

0 commit comments

Comments
 (0)