Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEAT] Math Ops for FixedSizeList / FixedShapeTensor / Embedding Type #2507

Merged
merged 7 commits into from
Jul 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/daft-core/src/array/fixed_size_list_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,9 @@
&& (validity.len() * size) != flat_child.len()
{
panic!(
"FixedSizeListArray::new received values with len {} but expected it to match len of validity * size: {}",
"FixedSizeListArray::new received values with len {} but expected it to match len of validity {} * size: {}",

Check warning on line 36 in src/daft-core/src/array/fixed_size_list_array.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-core/src/array/fixed_size_list_array.rs#L36

Added line #L36 was not covered by tests
flat_child.len(),
validity.len(),

Check warning on line 38 in src/daft-core/src/array/fixed_size_list_array.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-core/src/array/fixed_size_list_array.rs#L38

Added line #L38 was not covered by tests
(validity.len() * size),
)
}
Expand Down
95 changes: 93 additions & 2 deletions src/daft-core/src/array/ops/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
use arrow2::{array::PrimitiveArray, compute::arithmetics::basic};

use crate::{
array::DataArray,
datatypes::{DaftNumericType, Float64Array, Int64Array, Utf8Array},
array::{DataArray, FixedSizeListArray},
datatypes::{DaftNumericType, Field, Float64Array, Int64Array, Utf8Array},
kernels::utf8::add_utf8_arrays,
DataType, Series,
};

use common_error::{DaftError, DaftResult};
Expand Down Expand Up @@ -199,3 +200,93 @@
}
}
}

fn fixed_sized_list_arithmetic_helper<Kernel>(
lhs: &FixedSizeListArray,
rhs: &FixedSizeListArray,
kernel: Kernel,
) -> DaftResult<FixedSizeListArray>
where
Kernel: Fn(&Series, &Series) -> DaftResult<Series>,
{
assert_eq!(lhs.fixed_element_len(), rhs.fixed_element_len());

let lhs_child: &Series = &lhs.flat_child;
let rhs_child: &Series = &rhs.flat_child;
let lhs_len = lhs.len();
let rhs_len = rhs.len();

let (result_child, validity) = match (lhs_len, rhs_len) {
(a, b) if a == b => Ok((
kernel(lhs_child, rhs_child)?,
crate::utils::arrow::arrow_bitmap_and_helper(lhs.validity(), rhs.validity()),
)),
(l, 1) => {
let validity = if rhs.is_valid(0) {
lhs.validity().cloned()
} else {
Some(arrow2::bitmap::Bitmap::new_zeroed(l))
};
Ok((kernel(lhs_child, &rhs_child.repeat(lhs_len)?)?, validity))
}
(1, r) => {
let validity = if lhs.is_valid(0) {
rhs.validity().cloned()

Check warning on line 234 in src/daft-core/src/array/ops/arithmetic.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-core/src/array/ops/arithmetic.rs#L232-L234

Added lines #L232 - L234 were not covered by tests
} else {
Some(arrow2::bitmap::Bitmap::new_zeroed(r))

Check warning on line 236 in src/daft-core/src/array/ops/arithmetic.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-core/src/array/ops/arithmetic.rs#L236

Added line #L236 was not covered by tests
};
Ok((kernel(&lhs_child.repeat(lhs_len)?, rhs_child)?, validity))

Check warning on line 238 in src/daft-core/src/array/ops/arithmetic.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-core/src/array/ops/arithmetic.rs#L238

Added line #L238 was not covered by tests
}
(a, b) => Err(DaftError::ValueError(format!(
"Cannot apply operation on arrays of different lengths: {a} vs {b}"
))),
}?;

Check warning on line 243 in src/daft-core/src/array/ops/arithmetic.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-core/src/array/ops/arithmetic.rs#L240-L243

Added lines #L240 - L243 were not covered by tests

let result_field = Field::new(
lhs.name(),
DataType::FixedSizeList(
Box::new(result_child.data_type().clone()),
lhs.fixed_element_len(),
),
);
Ok(FixedSizeListArray::new(
result_field,
result_child,
validity,
))
}

impl Add for &FixedSizeListArray {
type Output = DaftResult<FixedSizeListArray>;
fn add(self, rhs: Self) -> Self::Output {
fixed_sized_list_arithmetic_helper(self, rhs, |a, b| a + b)
}
}

impl Mul for &FixedSizeListArray {
type Output = DaftResult<FixedSizeListArray>;
fn mul(self, rhs: Self) -> Self::Output {
fixed_sized_list_arithmetic_helper(self, rhs, |a, b| a * b)
}
}

impl Sub for &FixedSizeListArray {
type Output = DaftResult<FixedSizeListArray>;
fn sub(self, rhs: Self) -> Self::Output {
fixed_sized_list_arithmetic_helper(self, rhs, |a, b| a - b)
}
}

impl Div for &FixedSizeListArray {
type Output = DaftResult<FixedSizeListArray>;
fn div(self, rhs: Self) -> Self::Output {
fixed_sized_list_arithmetic_helper(self, rhs, |a, b| a / b)
}
}

impl Rem for &FixedSizeListArray {
type Output = DaftResult<FixedSizeListArray>;
fn rem(self, rhs: Self) -> Self::Output {
fixed_sized_list_arithmetic_helper(self, rhs, |a, b| a % b)
}

Check warning on line 291 in src/daft-core/src/array/ops/arithmetic.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-core/src/array/ops/arithmetic.rs#L289-L291

Added lines #L289 - L291 were not covered by tests
}
108 changes: 90 additions & 18 deletions src/daft-core/src/datatypes/binary_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@

fn add(self, other: Self) -> Self::Output {
use DataType::*;
try_numeric_supertype(self, other).or(
try_numeric_supertype(self, other).or(try_fixed_shape_numeric_datatype(self, other, |l, r| {l + r})).or(
match (self, other) {
#[cfg(feature = "python")]
(Python, _) | (_, Python) => Ok(Python),
Expand Down Expand Up @@ -149,7 +149,7 @@

fn sub(self, other: Self) -> Self::Output {
use DataType::*;
try_numeric_supertype(self, other).or(
try_numeric_supertype(self, other).or(try_fixed_shape_numeric_datatype(self, other, |l, r| {l - r})).or(
match (self, other) {
#[cfg(feature = "python")]
(Python, _) | (_, Python) => Ok(Python),
Expand Down Expand Up @@ -193,6 +193,7 @@
self, other
))),
}
.or(try_fixed_shape_numeric_datatype(self, other, |l, r| l / r))
}
}

Expand All @@ -201,14 +202,16 @@

fn mul(self, other: Self) -> Self::Output {
use DataType::*;
try_numeric_supertype(self, other).or(match (self, other) {
#[cfg(feature = "python")]
(Python, _) | (_, Python) => Ok(Python),
_ => Err(DaftError::TypeError(format!(
"Cannot multiply types: {}, {}",
self, other
))),
})
try_numeric_supertype(self, other)
.or(try_fixed_shape_numeric_datatype(self, other, |l, r| l * r))
.or(match (self, other) {
#[cfg(feature = "python")]
(Python, _) | (_, Python) => Ok(Python),
_ => Err(DaftError::TypeError(format!(
"Cannot multiply types: {}, {}",
self, other
))),
})
}
}

Expand All @@ -217,14 +220,16 @@

fn rem(self, other: Self) -> Self::Output {
use DataType::*;
try_numeric_supertype(self, other).or(match (self, other) {
#[cfg(feature = "python")]
(Python, _) | (_, Python) => Ok(Python),
_ => Err(DaftError::TypeError(format!(
"Cannot multiply types: {}, {}",
self, other
))),
})
try_numeric_supertype(self, other)
.or(try_fixed_shape_numeric_datatype(self, other, |l, r| l % r))
.or(match (self, other) {
#[cfg(feature = "python")]
(Python, _) | (_, Python) => Ok(Python),
_ => Err(DaftError::TypeError(format!(
"Cannot multiply types: {}, {}",
self, other
))),
})
}
}

Expand Down Expand Up @@ -367,3 +372,70 @@
l, r
)))
}

pub fn try_fixed_shape_numeric_datatype<F>(
l: &DataType,
r: &DataType,
inner_f: F,
) -> DaftResult<DataType>
where
F: Fn(&DataType, &DataType) -> DaftResult<DataType>,
{
use DataType::*;

match (l, r) {
(FixedShapeTensor(ldtype, lshape), FixedShapeTensor(rdtype, rshape)) => {
if lshape != rshape {
Err(DaftError::TypeError(format!(
"Cannot add types: {}, {} due to shape mismatch",
l, r
)))

Check warning on line 392 in src/daft-core/src/datatypes/binary_ops.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-core/src/datatypes/binary_ops.rs#L389-L392

Added lines #L389 - L392 were not covered by tests
} else if let Ok(result_type) = inner_f(ldtype.as_ref(), rdtype.as_ref())
&& result_type.is_numeric()
{
Ok(FixedShapeTensor(Box::new(result_type), lshape.clone()))
} else {
Err(DaftError::TypeError(format!(
"Cannot add types: {}, {}",
l, r
)))

Check warning on line 401 in src/daft-core/src/datatypes/binary_ops.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-core/src/datatypes/binary_ops.rs#L398-L401

Added lines #L398 - L401 were not covered by tests
}
}
(FixedSizeList(ldtype, lsize), FixedSizeList(rdtype, rsize)) => {
if lsize != rsize {
Err(DaftError::TypeError(format!(
"Cannot add types: {}, {} due to shape mismatch",
l, r
)))
} else if let Ok(result_type) = inner_f(ldtype.as_ref(), rdtype.as_ref()) {
Ok(FixedSizeList(Box::new(result_type), *lsize))

Check warning on line 411 in src/daft-core/src/datatypes/binary_ops.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-core/src/datatypes/binary_ops.rs#L404-L411

Added lines #L404 - L411 were not covered by tests
} else {
Err(DaftError::TypeError(format!(
"Cannot add types: {}, {}",
l, r
)))

Check warning on line 416 in src/daft-core/src/datatypes/binary_ops.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-core/src/datatypes/binary_ops.rs#L413-L416

Added lines #L413 - L416 were not covered by tests
}
}
(Embedding(ldtype, lsize), Embedding(rdtype, rsize)) => {
if lsize != rsize {
Err(DaftError::TypeError(format!(
"Cannot add types: {}, {} due to shape mismatch",
l, r
)))
} else if let Ok(result_type) = inner_f(ldtype.as_ref(), rdtype.as_ref())
&& result_type.is_numeric()

Check warning on line 426 in src/daft-core/src/datatypes/binary_ops.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-core/src/datatypes/binary_ops.rs#L419-L426

Added lines #L419 - L426 were not covered by tests
{
Ok(Embedding(Box::new(result_type), *lsize))

Check warning on line 428 in src/daft-core/src/datatypes/binary_ops.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-core/src/datatypes/binary_ops.rs#L428

Added line #L428 was not covered by tests
} else {
Err(DaftError::TypeError(format!(
"Cannot add types: {}, {}",
l, r
)))

Check warning on line 433 in src/daft-core/src/datatypes/binary_ops.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-core/src/datatypes/binary_ops.rs#L430-L433

Added lines #L430 - L433 were not covered by tests
}
}
_ => Err(DaftError::TypeError(format!(
"Invalid arguments to numeric supertype: {}, {}",
l, r
))),
}
}
10 changes: 10 additions & 0 deletions src/daft-core/src/datatypes/dtype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,16 @@
}
}

#[inline]
pub fn is_fixed_size_numeric(&self) -> bool {
match self {
DataType::FixedSizeList(dtype, ..)
| DataType::Embedding(dtype, ..)

Check warning on line 277 in src/daft-core/src/datatypes/dtype.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-core/src/datatypes/dtype.rs#L276-L277

Added lines #L276 - L277 were not covered by tests
| DataType::FixedShapeTensor(dtype, ..) => dtype.is_numeric(),
_ => false,

Check warning on line 279 in src/daft-core/src/datatypes/dtype.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-core/src/datatypes/dtype.rs#L279

Added line #L279 was not covered by tests
}
}

#[inline]
pub fn is_integer(&self) -> bool {
matches!(
Expand Down
1 change: 1 addition & 0 deletions src/daft-core/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#![feature(let_chains)]
#![feature(int_roundings)]
#![feature(iterator_try_reduce)]
#![feature(if_let_guard)]

pub mod array;
pub mod count_mode;
Expand Down
58 changes: 57 additions & 1 deletion src/daft-core/src/series/array_impl/binary_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::{
},
datatypes::{
logical::{Decimal128Array, MapArray},
FixedSizeBinaryArray, Int128Array,
Field, FixedSizeBinaryArray, Int128Array,
},
series::series_like::SeriesLike,
with_match_comparable_daft_types, with_match_integer_daft_types, with_match_numeric_daft_types,
Expand Down Expand Up @@ -61,6 +61,53 @@ macro_rules! cast_downcast_op_into_series {
}};
}

macro_rules! apply_fixed_numeric_op {
($lhs:expr, $rhs:expr, $op:ident) => {{
$lhs.$op($rhs)?
}};
}

macro_rules! fixed_sized_numeric_binary_op {
($left:expr, $right:expr, $output_type:expr, $op:ident) => {{
assert!($left.data_type().is_fixed_size_numeric());
assert!($right.data_type().is_fixed_size_numeric());

match ($left.data_type(), $right.data_type()) {
(DataType::FixedSizeList(..), DataType::FixedSizeList(..)) => {
Ok(apply_fixed_numeric_op!(
$left.downcast::<FixedSizeListArray>().unwrap(),
$right.downcast::<FixedSizeListArray>().unwrap(),
$op
)
.into_series())
}
(DataType::Embedding(..), DataType::Embedding(..)) => {
let physical = apply_fixed_numeric_op!(
&$left.downcast::<EmbeddingArray>().unwrap().physical,
&$right.downcast::<EmbeddingArray>().unwrap().physical,
$op
);
let array =
EmbeddingArray::new(Field::new($left.name(), $output_type.clone()), physical);
Ok(array.into_series())
}
(DataType::FixedShapeTensor(..), DataType::FixedShapeTensor(..)) => {
let physical = apply_fixed_numeric_op!(
&$left.downcast::<FixedShapeTensorArray>().unwrap().physical,
&$right.downcast::<FixedShapeTensorArray>().unwrap().physical,
$op
);
let array = FixedShapeTensorArray::new(
Field::new($left.name(), $output_type.clone()),
physical,
);
Ok(array.into_series())
}
(left, right) => unimplemented!("cannot add {left} and {right} types"),
}
}};
}

macro_rules! binary_op_unimplemented {
($lhs:expr, $op:expr, $rhs:expr, $output_ty:expr) => {
unimplemented!(
Expand Down Expand Up @@ -92,6 +139,9 @@ macro_rules! py_numeric_binary_op {
)
})
}
output_type if output_type.is_fixed_size_numeric() => {
fixed_sized_numeric_binary_op!(&lhs, $rhs, output_type, $op)
}
_ => binary_op_unimplemented!(lhs, $pyop, $rhs, output_type),
}
}};
Expand Down Expand Up @@ -167,6 +217,9 @@ pub(crate) trait SeriesBinaryOps: SeriesLike {
cast_downcast_op_into_series!(lhs, rhs, output_type, <$T as DaftDataType>::ArrayType, add)
})
}
output_type if output_type.is_fixed_size_numeric() => {
fixed_sized_numeric_binary_op!(&lhs, rhs, output_type, add)
}
_ => binary_op_unimplemented!(lhs, "+", rhs, output_type),
}
}
Expand All @@ -184,6 +237,9 @@ pub(crate) trait SeriesBinaryOps: SeriesLike {
#[cfg(feature = "python")]
Python => Ok(py_binary_op!(lhs, rhs, "truediv")),
Float64 => cast_downcast_op_into_series!(lhs, rhs, &Float64, Float64Array, div),
output_type if output_type.is_fixed_size_numeric() => {
fixed_sized_numeric_binary_op!(&lhs, rhs, output_type, div)
}
_ => binary_op_unimplemented!(lhs, "/", rhs, output_type),
}
}
Expand Down
1 change: 1 addition & 0 deletions src/daft-core/src/series/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ pub mod minhash;
pub mod not;
pub mod null;
pub mod partitioning;
pub mod repeat;
pub mod round;
pub mod search_sorted;
pub mod shift;
Expand Down
Loading
Loading