Skip to content

Commit

Permalink
[FEAT] Math Ops for FixedSizeList / FixedShapeTensor / Embedding Type (
Browse files Browse the repository at this point in the history
  • Loading branch information
samster25 authored Jul 16, 2024
1 parent 924c905 commit b4446b0
Show file tree
Hide file tree
Showing 9 changed files with 336 additions and 22 deletions.
3 changes: 2 additions & 1 deletion src/daft-core/src/array/fixed_size_list_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,9 @@ impl FixedSizeListArray {
&& (validity.len() * size) != flat_child.len()
{
panic!(
"FixedSizeListArray::new received values with len {} but expected it to match len of validity * size: {}",
"FixedSizeListArray::new received values with len {} but expected it to match len of validity {} * size: {}",
flat_child.len(),
validity.len(),
(validity.len() * size),
)
}
Expand Down
95 changes: 93 additions & 2 deletions src/daft-core/src/array/ops/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@ use std::ops::{Add, Div, Mul, Rem, Sub};
use arrow2::{array::PrimitiveArray, compute::arithmetics::basic};

use crate::{
array::DataArray,
datatypes::{DaftNumericType, Float64Array, Int64Array, Utf8Array},
array::{DataArray, FixedSizeListArray},
datatypes::{DaftNumericType, Field, Float64Array, Int64Array, Utf8Array},
kernels::utf8::add_utf8_arrays,
DataType, Series,
};

use common_error::{DaftError, DaftResult};
Expand Down Expand Up @@ -199,3 +200,93 @@ where
}
}
}

/// Applies a binary `kernel` to the flat children of two fixed-size list
/// arrays and wraps the result in a new [`FixedSizeListArray`].
///
/// Length handling:
/// * equal lengths: the children are passed to `kernel` as-is and the
///   validity bitmaps are combined with `arrow_bitmap_and_helper`;
/// * one side has length 1: that side's flat child is repeated to the other
///   side's row count before calling `kernel` (assumes `Series::repeat(n)`
///   tiles the series `n` times; confirm against `series::ops::repeat`).
///   If the length-1 row is null, every output row is null; otherwise the
///   longer side's validity is reused;
/// * anything else: a `DaftError::ValueError` is returned.
///
/// The result keeps the name of `lhs`; its child dtype is whatever `kernel`
/// produced.
///
/// # Panics
/// Panics if `lhs` and `rhs` have different fixed element lengths.
///
/// # Errors
/// Propagates errors from `kernel` and from `Series::repeat`.
fn fixed_sized_list_arithmetic_helper<Kernel>(
    lhs: &FixedSizeListArray,
    rhs: &FixedSizeListArray,
    kernel: Kernel,
) -> DaftResult<FixedSizeListArray>
where
    Kernel: Fn(&Series, &Series) -> DaftResult<Series>,
{
    assert_eq!(lhs.fixed_element_len(), rhs.fixed_element_len());

    let lhs_child: &Series = &lhs.flat_child;
    let rhs_child: &Series = &rhs.flat_child;

    let (result_child, validity) = match (lhs.len(), rhs.len()) {
        (a, b) if a == b => (
            kernel(lhs_child, rhs_child)?,
            crate::utils::arrow::arrow_bitmap_and_helper(lhs.validity(), rhs.validity()),
        ),
        // Broadcast the single right-hand row across every left-hand row.
        (l, 1) => {
            let validity = if rhs.is_valid(0) {
                lhs.validity().cloned()
            } else {
                Some(arrow2::bitmap::Bitmap::new_zeroed(l))
            };
            (kernel(lhs_child, &rhs_child.repeat(l)?)?, validity)
        }
        // Broadcast the single left-hand row across every right-hand row.
        // The repeat count must be the *right* length: repeating by the left
        // length (which is 1 here) would leave the child un-broadcast.
        (1, r) => {
            let validity = if lhs.is_valid(0) {
                rhs.validity().cloned()
            } else {
                Some(arrow2::bitmap::Bitmap::new_zeroed(r))
            };
            (kernel(&lhs_child.repeat(r)?, rhs_child)?, validity)
        }
        (a, b) => {
            return Err(DaftError::ValueError(format!(
                "Cannot apply operation on arrays of different lengths: {a} vs {b}"
            )))
        }
    };

    let result_field = Field::new(
        lhs.name(),
        DataType::FixedSizeList(
            Box::new(result_child.data_type().clone()),
            lhs.fixed_element_len(),
        ),
    );
    Ok(FixedSizeListArray::new(
        result_field,
        result_child,
        validity,
    ))
}

impl Add for &FixedSizeListArray {
type Output = DaftResult<FixedSizeListArray>;
fn add(self, rhs: Self) -> Self::Output {
fixed_sized_list_arithmetic_helper(self, rhs, |a, b| a + b)
}
}

impl Mul for &FixedSizeListArray {
type Output = DaftResult<FixedSizeListArray>;
fn mul(self, rhs: Self) -> Self::Output {
fixed_sized_list_arithmetic_helper(self, rhs, |a, b| a * b)
}
}

impl Sub for &FixedSizeListArray {
type Output = DaftResult<FixedSizeListArray>;
fn sub(self, rhs: Self) -> Self::Output {
fixed_sized_list_arithmetic_helper(self, rhs, |a, b| a - b)
}
}

impl Div for &FixedSizeListArray {
type Output = DaftResult<FixedSizeListArray>;
fn div(self, rhs: Self) -> Self::Output {
fixed_sized_list_arithmetic_helper(self, rhs, |a, b| a / b)
}
}

impl Rem for &FixedSizeListArray {
type Output = DaftResult<FixedSizeListArray>;
fn rem(self, rhs: Self) -> Self::Output {
fixed_sized_list_arithmetic_helper(self, rhs, |a, b| a % b)
}
}
108 changes: 90 additions & 18 deletions src/daft-core/src/datatypes/binary_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ impl Add for &DataType {

fn add(self, other: Self) -> Self::Output {
use DataType::*;
try_numeric_supertype(self, other).or(
try_numeric_supertype(self, other).or(try_fixed_shape_numeric_datatype(self, other, |l, r| {l + r})).or(
match (self, other) {
#[cfg(feature = "python")]
(Python, _) | (_, Python) => Ok(Python),
Expand Down Expand Up @@ -149,7 +149,7 @@ impl Sub for &DataType {

fn sub(self, other: Self) -> Self::Output {
use DataType::*;
try_numeric_supertype(self, other).or(
try_numeric_supertype(self, other).or(try_fixed_shape_numeric_datatype(self, other, |l, r| {l - r})).or(
match (self, other) {
#[cfg(feature = "python")]
(Python, _) | (_, Python) => Ok(Python),
Expand Down Expand Up @@ -193,6 +193,7 @@ impl Div for &DataType {
self, other
))),
}
.or(try_fixed_shape_numeric_datatype(self, other, |l, r| l / r))
}
}

Expand All @@ -201,14 +202,16 @@ impl Mul for &DataType {

fn mul(self, other: Self) -> Self::Output {
use DataType::*;
try_numeric_supertype(self, other).or(match (self, other) {
#[cfg(feature = "python")]
(Python, _) | (_, Python) => Ok(Python),
_ => Err(DaftError::TypeError(format!(
"Cannot multiply types: {}, {}",
self, other
))),
})
try_numeric_supertype(self, other)
.or(try_fixed_shape_numeric_datatype(self, other, |l, r| l * r))
.or(match (self, other) {
#[cfg(feature = "python")]
(Python, _) | (_, Python) => Ok(Python),
_ => Err(DaftError::TypeError(format!(
"Cannot multiply types: {}, {}",
self, other
))),
})
}
}

Expand All @@ -217,14 +220,16 @@ impl Rem for &DataType {

fn rem(self, other: Self) -> Self::Output {
use DataType::*;
try_numeric_supertype(self, other).or(match (self, other) {
#[cfg(feature = "python")]
(Python, _) | (_, Python) => Ok(Python),
_ => Err(DaftError::TypeError(format!(
"Cannot multiply types: {}, {}",
self, other
))),
})
try_numeric_supertype(self, other)
.or(try_fixed_shape_numeric_datatype(self, other, |l, r| l % r))
.or(match (self, other) {
#[cfg(feature = "python")]
(Python, _) | (_, Python) => Ok(Python),
_ => Err(DaftError::TypeError(format!(
"Cannot multiply types: {}, {}",
self, other
))),
})
}
}

Expand Down Expand Up @@ -367,3 +372,70 @@ pub fn try_numeric_supertype(l: &DataType, r: &DataType) -> DaftResult<DataType>
l, r
)))
}

/// Resolves the output dtype of a binary arithmetic operation between two
/// fixed-shape container types (`FixedShapeTensor`, `FixedSizeList`,
/// `Embedding`).
///
/// Both sides must be the same kind of container with the same shape/size.
/// The element dtypes are combined with `inner_f`, and the result is wrapped
/// back into the same container kind.
///
/// # Errors
/// Returns `DaftError::TypeError` when:
/// * the shapes/sizes differ;
/// * `inner_f` fails on the element dtypes;
/// * for `FixedShapeTensor` and `Embedding`, the combined element dtype is
///   not numeric;
/// * `l` and `r` are not a matching pair of the supported container kinds.
///
/// The error text is operation-neutral because this function is shared by
/// every arithmetic operator, not only addition.
pub fn try_fixed_shape_numeric_datatype<F>(
    l: &DataType,
    r: &DataType,
    inner_f: F,
) -> DaftResult<DataType>
where
    F: Fn(&DataType, &DataType) -> DaftResult<DataType>,
{
    use DataType::*;

    let shape_mismatch = || {
        DaftError::TypeError(format!(
            "Cannot apply arithmetic operation to types: {}, {} due to shape mismatch",
            l, r
        ))
    };
    let unsupported = || {
        DaftError::TypeError(format!(
            "Cannot apply arithmetic operation to types: {}, {}",
            l, r
        ))
    };

    match (l, r) {
        (FixedShapeTensor(ldtype, lshape), FixedShapeTensor(rdtype, rshape)) => {
            if lshape != rshape {
                return Err(shape_mismatch());
            }
            match inner_f(ldtype.as_ref(), rdtype.as_ref()) {
                Ok(result_type) if result_type.is_numeric() => {
                    Ok(FixedShapeTensor(Box::new(result_type), lshape.clone()))
                }
                _ => Err(unsupported()),
            }
        }
        (FixedSizeList(ldtype, lsize), FixedSizeList(rdtype, rsize)) => {
            if lsize != rsize {
                return Err(shape_mismatch());
            }
            // NOTE(review): unlike the tensor and embedding arms, this arm does
            // not require the combined element dtype to be numeric. Kept as-is;
            // confirm whether that asymmetry is intentional.
            match inner_f(ldtype.as_ref(), rdtype.as_ref()) {
                Ok(result_type) => Ok(FixedSizeList(Box::new(result_type), *lsize)),
                Err(_) => Err(unsupported()),
            }
        }
        (Embedding(ldtype, lsize), Embedding(rdtype, rsize)) => {
            if lsize != rsize {
                return Err(shape_mismatch());
            }
            match inner_f(ldtype.as_ref(), rdtype.as_ref()) {
                Ok(result_type) if result_type.is_numeric() => {
                    Ok(Embedding(Box::new(result_type), *lsize))
                }
                _ => Err(unsupported()),
            }
        }
        _ => Err(DaftError::TypeError(format!(
            "Invalid arguments to numeric supertype: {}, {}",
            l, r
        ))),
    }
}
10 changes: 10 additions & 0 deletions src/daft-core/src/datatypes/dtype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,16 @@ impl DataType {
}
}

/// Returns `true` when `self` is a `FixedSizeList`, `Embedding` or
/// `FixedShapeTensor` whose element dtype is numeric; `false` for every
/// other dtype.
#[inline]
pub fn is_fixed_size_numeric(&self) -> bool {
    matches!(
        self,
        DataType::FixedSizeList(inner, ..)
            | DataType::Embedding(inner, ..)
            | DataType::FixedShapeTensor(inner, ..)
            if inner.is_numeric()
    )
}

#[inline]
pub fn is_integer(&self) -> bool {
matches!(
Expand Down
1 change: 1 addition & 0 deletions src/daft-core/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#![feature(let_chains)]
#![feature(int_roundings)]
#![feature(iterator_try_reduce)]
#![feature(if_let_guard)]

pub mod array;
pub mod count_mode;
Expand Down
58 changes: 57 additions & 1 deletion src/daft-core/src/series/array_impl/binary_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::{
},
datatypes::{
logical::{Decimal128Array, MapArray},
FixedSizeBinaryArray, Int128Array,
Field, FixedSizeBinaryArray, Int128Array,
},
series::series_like::SeriesLike,
with_match_comparable_daft_types, with_match_integer_daft_types, with_match_numeric_daft_types,
Expand Down Expand Up @@ -61,6 +61,53 @@ macro_rules! cast_downcast_op_into_series {
}};
}

/// Calls the method named `$method` on `$receiver` with `$argument` as its
/// only argument and unwraps the outcome with `?`, so it must be expanded
/// inside a function whose return type supports `?`.
macro_rules! apply_fixed_numeric_op {
    ($receiver:expr, $argument:expr, $method:ident) => {
        $receiver.$method($argument)?
    };
}

/// Dispatches the binary operator `$op` over two series whose dtypes are both
/// fixed-size numeric containers of the same kind, producing a `DaftResult`
/// holding the resulting series.
///
/// * `FixedSizeList`: both sides are downcast and `$op` is applied directly.
/// * `Embedding` / `FixedShapeTensor`: `$op` is applied to the `physical`
///   arrays and the result is re-wrapped in the logical array type using the
///   left operand's name and `$output_type`.
///
/// Errors from `$op` are propagated with `?`, so the macro must be expanded
/// inside a function returning a compatible `Result`.
///
/// # Panics
/// Panics if either operand's dtype is not fixed-size numeric, if a downcast
/// fails, or if the two operands are different container kinds.
macro_rules! fixed_sized_numeric_binary_op {
    ($left:expr, $right:expr, $output_type:expr, $op:ident) => {{
        assert!($left.data_type().is_fixed_size_numeric());
        assert!($right.data_type().is_fixed_size_numeric());

        match ($left.data_type(), $right.data_type()) {
            (DataType::FixedSizeList(..), DataType::FixedSizeList(..)) => {
                Ok(apply_fixed_numeric_op!(
                    $left.downcast::<FixedSizeListArray>().unwrap(),
                    $right.downcast::<FixedSizeListArray>().unwrap(),
                    $op
                )
                .into_series())
            }
            (DataType::Embedding(..), DataType::Embedding(..)) => {
                // Operate on the physical storage, then restore the logical type.
                let physical = apply_fixed_numeric_op!(
                    &$left.downcast::<EmbeddingArray>().unwrap().physical,
                    &$right.downcast::<EmbeddingArray>().unwrap().physical,
                    $op
                );
                let array =
                    EmbeddingArray::new(Field::new($left.name(), $output_type.clone()), physical);
                Ok(array.into_series())
            }
            (DataType::FixedShapeTensor(..), DataType::FixedShapeTensor(..)) => {
                // Operate on the physical storage, then restore the logical type.
                let physical = apply_fixed_numeric_op!(
                    &$left.downcast::<FixedShapeTensorArray>().unwrap().physical,
                    &$right.downcast::<FixedShapeTensorArray>().unwrap().physical,
                    $op
                );
                let array = FixedShapeTensorArray::new(
                    Field::new($left.name(), $output_type.clone()),
                    physical,
                );
                Ok(array.into_series())
            }
            // Name the operator actually requested rather than always "add".
            (left, right) => {
                unimplemented!("cannot {} {left} and {right} types", stringify!($op))
            }
        }
    }};
}

macro_rules! binary_op_unimplemented {
($lhs:expr, $op:expr, $rhs:expr, $output_ty:expr) => {
unimplemented!(
Expand Down Expand Up @@ -92,6 +139,9 @@ macro_rules! py_numeric_binary_op {
)
})
}
output_type if output_type.is_fixed_size_numeric() => {
fixed_sized_numeric_binary_op!(&lhs, $rhs, output_type, $op)
}
_ => binary_op_unimplemented!(lhs, $pyop, $rhs, output_type),
}
}};
Expand Down Expand Up @@ -167,6 +217,9 @@ pub(crate) trait SeriesBinaryOps: SeriesLike {
cast_downcast_op_into_series!(lhs, rhs, output_type, <$T as DaftDataType>::ArrayType, add)
})
}
output_type if output_type.is_fixed_size_numeric() => {
fixed_sized_numeric_binary_op!(&lhs, rhs, output_type, add)
}
_ => binary_op_unimplemented!(lhs, "+", rhs, output_type),
}
}
Expand All @@ -184,6 +237,9 @@ pub(crate) trait SeriesBinaryOps: SeriesLike {
#[cfg(feature = "python")]
Python => Ok(py_binary_op!(lhs, rhs, "truediv")),
Float64 => cast_downcast_op_into_series!(lhs, rhs, &Float64, Float64Array, div),
output_type if output_type.is_fixed_size_numeric() => {
fixed_sized_numeric_binary_op!(&lhs, rhs, output_type, div)
}
_ => binary_op_unimplemented!(lhs, "/", rhs, output_type),
}
}
Expand Down
1 change: 1 addition & 0 deletions src/daft-core/src/series/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ pub mod minhash;
pub mod not;
pub mod null;
pub mod partitioning;
pub mod repeat;
pub mod round;
pub mod search_sorted;
pub mod shift;
Expand Down
Loading

0 comments on commit b4446b0

Please sign in to comment.