diff --git a/src/daft-core/src/array/ops/cast.rs b/src/daft-core/src/array/ops/cast.rs index 582b65212d..a8f8d46bdb 100644 --- a/src/daft-core/src/array/ops/cast.rs +++ b/src/daft-core/src/array/ops/cast.rs @@ -69,10 +69,13 @@ where // Get the result of the Arrow Logical->Target cast. let result_arrow_array = { // First, get corresponding Arrow LogicalArray of source DataArray - use DataType::*; let source_arrow_array = match source_dtype { // Wrapped primitives - Decimal128(..) | Date | Timestamp(..) | Duration(..) | Time(..) => { + DataType::Decimal128(..) + | DataType::Date + | DataType::Timestamp(..) + | DataType::Duration(..) + | DataType::Time(..) => { with_match_daft_logical_primitive_types!(source_dtype, |$T| { use arrow2::array::Array; to_cast @@ -111,11 +114,14 @@ where // If the target type is also Logical, get the Arrow Physical. let result_arrow_physical_array = { if dtype.is_logical() { - use DataType::*; let target_physical_type = dtype.to_physical().to_arrow()?; match dtype { // Primitive wrapper types: change the arrow2 array's type field to primitive - Decimal128(..) | Date | Timestamp(..) | Duration(..) | Time(..) => { + DataType::Decimal128(..) + | DataType::Date + | DataType::Timestamp(..) + | DataType::Duration(..) + | DataType::Time(..) => { with_match_daft_logical_primitive_types!(dtype, |$P| { use arrow2::array::Array; result_arrow_array diff --git a/src/daft-core/src/array/ops/image.rs b/src/daft-core/src/array/ops/image.rs index 37b3c92a00..0e6316cfb1 100644 --- a/src/daft-core/src/array/ops/image.rs +++ b/src/daft-core/src/array/ops/image.rs @@ -57,19 +57,17 @@ macro_rules! 
with_method_on_image_buffer { ( $key_type:expr, $method: ident ) => {{ - use DaftImageBuffer::*; - match $key_type { - L(img) => img.$method(), - LA(img) => img.$method(), - RGB(img) => img.$method(), - RGBA(img) => img.$method(), - L16(img) => img.$method(), - LA16(img) => img.$method(), - RGB16(img) => img.$method(), - RGBA16(img) => img.$method(), - RGB32F(img) => img.$method(), - RGBA32F(img) => img.$method(), + DaftImageBuffer::L(img) => img.$method(), + DaftImageBuffer::LA(img) => img.$method(), + DaftImageBuffer::RGB(img) => img.$method(), + DaftImageBuffer::RGBA(img) => img.$method(), + DaftImageBuffer::L16(img) => img.$method(), + DaftImageBuffer::LA16(img) => img.$method(), + DaftImageBuffer::RGB16(img) => img.$method(), + DaftImageBuffer::RGBA16(img) => img.$method(), + DaftImageBuffer::RGB32F(img) => img.$method(), + DaftImageBuffer::RGBA32F(img) => img.$method(), } }}; } @@ -148,19 +146,17 @@ impl From> for image::ImageFormat { impl From> for image::ColorType { fn from(image_mode: Wrap) -> image::ColorType { use image::ColorType; - use ImageMode::*; - match image_mode.0 { - L => ColorType::L8, - LA => ColorType::La8, - RGB => ColorType::Rgb8, - RGBA => ColorType::Rgba8, - L16 => ColorType::L16, - LA16 => ColorType::La16, - RGB16 => ColorType::Rgb16, - RGBA16 => ColorType::Rgba16, - RGB32F => ColorType::Rgb32F, - RGBA32F => ColorType::Rgba32F, + ImageMode::L => ColorType::L8, + ImageMode::LA => ColorType::La8, + ImageMode::RGB => ColorType::Rgb8, + ImageMode::RGBA => ColorType::Rgba8, + ImageMode::L16 => ColorType::L16, + ImageMode::LA16 => ColorType::La16, + ImageMode::RGB16 => ColorType::Rgb16, + ImageMode::RGBA16 => ColorType::Rgba16, + ImageMode::RGB32F => ColorType::Rgb32F, + ImageMode::RGBA32F => ColorType::Rgba32F, } } } @@ -170,19 +166,17 @@ impl TryFrom for Wrap { fn try_from(color: image::ColorType) -> DaftResult { use image::ColorType; - use ImageMode::*; - Ok(Wrap(match color { - ColorType::L8 => Ok(L), - ColorType::La8 => Ok(LA), - 
ColorType::Rgb8 => Ok(RGB), - ColorType::Rgba8 => Ok(RGBA), - ColorType::L16 => Ok(L16), - ColorType::La16 => Ok(LA16), - ColorType::Rgb16 => Ok(RGB16), - ColorType::Rgba16 => Ok(RGBA16), - ColorType::Rgb32F => Ok(RGB32F), - ColorType::Rgba32F => Ok(RGBA32F), + ColorType::L8 => Ok(ImageMode::L), + ColorType::La8 => Ok(ImageMode::LA), + ColorType::Rgb8 => Ok(ImageMode::RGB), + ColorType::Rgba8 => Ok(ImageMode::RGBA), + ColorType::L16 => Ok(ImageMode::L16), + ColorType::La16 => Ok(ImageMode::LA16), + ColorType::Rgb16 => Ok(ImageMode::RGB16), + ColorType::Rgba16 => Ok(ImageMode::RGBA16), + ColorType::Rgb32F => Ok(ImageMode::RGB32F), + ColorType::Rgba32F => Ok(ImageMode::RGBA32F), _ => Err(DaftError::ValueError(format!( "Color type {:?} is not supported.", color @@ -201,12 +195,11 @@ impl<'a> DaftImageBuffer<'a> { } pub fn as_u8_slice(&'a self) -> &'a [u8] { - use DaftImageBuffer::*; match self { - L(img) => img.as_raw(), - LA(img) => img.as_raw(), - RGB(img) => img.as_raw(), - RGBA(img) => img.as_raw(), + DaftImageBuffer::L(img) => img.as_raw(), + DaftImageBuffer::LA(img) => img.as_raw(), + DaftImageBuffer::RGB(img) => img.as_raw(), + DaftImageBuffer::RGBA(img) => img.as_raw(), _ => unimplemented!("unimplemented {self:?}"), } } @@ -216,19 +209,17 @@ impl<'a> DaftImageBuffer<'a> { } pub fn mode(&self) -> ImageMode { - use DaftImageBuffer::*; - match self { - L(..) => ImageMode::L, - LA(..) => ImageMode::LA, - RGB(..) => ImageMode::RGB, - RGBA(..) => ImageMode::RGBA, - L16(..) => ImageMode::L16, - LA16(..) => ImageMode::LA16, - RGB16(..) => ImageMode::RGB16, - RGBA16(..) => ImageMode::RGBA16, - RGB32F(..) => ImageMode::RGB32F, - RGBA32F(..) => ImageMode::RGBA32F, + DaftImageBuffer::L(..) => ImageMode::L, + DaftImageBuffer::LA(..) => ImageMode::LA, + DaftImageBuffer::RGB(..) => ImageMode::RGB, + DaftImageBuffer::RGBA(..) => ImageMode::RGBA, + DaftImageBuffer::L16(..) => ImageMode::L16, + DaftImageBuffer::LA16(..) => ImageMode::LA16, + DaftImageBuffer::RGB16(..) 
=> ImageMode::RGB16, + DaftImageBuffer::RGBA16(..) => ImageMode::RGBA16, + DaftImageBuffer::RGB32F(..) => ImageMode::RGB32F, + DaftImageBuffer::RGBA32F(..) => ImageMode::RGBA32F, } } @@ -272,24 +263,23 @@ impl<'a> DaftImageBuffer<'a> { } pub fn resize(&self, w: u32, h: u32) -> Self { - use DaftImageBuffer::*; match self { - L(imgbuf) => { + DaftImageBuffer::L(imgbuf) => { let result = image::imageops::resize(imgbuf, w, h, image::imageops::FilterType::Triangle); DaftImageBuffer::L(image_buffer_vec_to_cow(result)) } - LA(imgbuf) => { + DaftImageBuffer::LA(imgbuf) => { let result = image::imageops::resize(imgbuf, w, h, image::imageops::FilterType::Triangle); DaftImageBuffer::LA(image_buffer_vec_to_cow(result)) } - RGB(imgbuf) => { + DaftImageBuffer::RGB(imgbuf) => { let result = image::imageops::resize(imgbuf, w, h, image::imageops::FilterType::Triangle); DaftImageBuffer::RGB(image_buffer_vec_to_cow(result)) } - RGBA(imgbuf) => { + DaftImageBuffer::RGBA(imgbuf) => { let result = image::imageops::resize(imgbuf, w, h, image::imageops::FilterType::Triangle); DaftImageBuffer::RGBA(image_buffer_vec_to_cow(result)) @@ -638,11 +628,15 @@ impl ImageArray { inputs: &[Option>], image_mode: &Option, ) -> DaftResult { - use DaftImageBuffer::*; - let is_all_u8 = inputs - .iter() - .filter_map(|b| b.as_ref()) - .all(|b| matches!(b, L(..) | LA(..) | RGB(..) | RGBA(..))); + let is_all_u8 = inputs.iter().filter_map(|b| b.as_ref()).all(|b| { + matches!( + b, + DaftImageBuffer::L(..) + | DaftImageBuffer::LA(..) + | DaftImageBuffer::RGB(..) + | DaftImageBuffer::RGBA(..) + ) + }); assert!(is_all_u8); let mut data_ref = Vec::with_capacity(inputs.len()); @@ -775,11 +769,15 @@ impl FixedShapeImageArray { height: u32, width: u32, ) -> DaftResult { - use DaftImageBuffer::*; - let is_all_u8 = inputs - .iter() - .filter_map(|b| b.as_ref()) - .all(|b| matches!(b, L(..) | LA(..) | RGB(..) 
| RGBA(..))); + let is_all_u8 = inputs.iter().filter_map(|b| b.as_ref()).all(|b| { + matches!( + b, + DaftImageBuffer::L(..) + | DaftImageBuffer::LA(..) + | DaftImageBuffer::RGB(..) + | DaftImageBuffer::RGBA(..) + ) + }); assert!(is_all_u8); let num_channels = image_mode.num_channels(); diff --git a/src/daft-core/src/array/ops/trigonometry.rs b/src/daft-core/src/array/ops/trigonometry.rs index 834e69b455..7bc45bab92 100644 --- a/src/daft-core/src/array/ops/trigonometry.rs +++ b/src/daft-core/src/array/ops/trigonometry.rs @@ -26,20 +26,19 @@ pub enum TrigonometricFunction { impl TrigonometricFunction { pub fn fn_name(&self) -> &'static str { - use TrigonometricFunction::*; match self { - Sin => "sin", - Cos => "cos", - Tan => "tan", - Cot => "cot", - ArcSin => "arcsin", - ArcCos => "arccos", - ArcTan => "arctan", - Radians => "radians", - Degrees => "degrees", - ArcTanh => "arctanh", - ArcCosh => "arccosh", - ArcSinh => "arcsinh", + TrigonometricFunction::Sin => "sin", + TrigonometricFunction::Cos => "cos", + TrigonometricFunction::Tan => "tan", + TrigonometricFunction::Cot => "cot", + TrigonometricFunction::ArcSin => "arcsin", + TrigonometricFunction::ArcCos => "arccos", + TrigonometricFunction::ArcTan => "arctan", + TrigonometricFunction::Radians => "radians", + TrigonometricFunction::Degrees => "degrees", + TrigonometricFunction::ArcTanh => "arctanh", + TrigonometricFunction::ArcCosh => "arccosh", + TrigonometricFunction::ArcSinh => "arcsinh", } } } @@ -50,20 +49,19 @@ where T::Native: Float, { pub fn trigonometry(&self, func: &TrigonometricFunction) -> DaftResult { - use TrigonometricFunction::*; match func { - Sin => self.apply(|v| v.sin()), - Cos => self.apply(|v| v.cos()), - Tan => self.apply(|v| v.tan()), - Cot => self.apply(|v| v.tan().powi(-1)), - ArcSin => self.apply(|v| v.asin()), - ArcCos => self.apply(|v| v.acos()), - ArcTan => self.apply(|v| v.atan()), - Radians => self.apply(|v| v.to_radians()), - Degrees => self.apply(|v| v.to_degrees()), - ArcTanh 
=> self.apply(|v| v.atanh()), - ArcCosh => self.apply(|v| v.acosh()), - ArcSinh => self.apply(|v| v.asinh()), + TrigonometricFunction::Sin => self.apply(|v| v.sin()), + TrigonometricFunction::Cos => self.apply(|v| v.cos()), + TrigonometricFunction::Tan => self.apply(|v| v.tan()), + TrigonometricFunction::Cot => self.apply(|v| v.tan().powi(-1)), + TrigonometricFunction::ArcSin => self.apply(|v| v.asin()), + TrigonometricFunction::ArcCos => self.apply(|v| v.acos()), + TrigonometricFunction::ArcTan => self.apply(|v| v.atan()), + TrigonometricFunction::Radians => self.apply(|v| v.to_radians()), + TrigonometricFunction::Degrees => self.apply(|v| v.to_degrees()), + TrigonometricFunction::ArcTanh => self.apply(|v| v.atanh()), + TrigonometricFunction::ArcCosh => self.apply(|v| v.acosh()), + TrigonometricFunction::ArcSinh => self.apply(|v| v.asinh()), } } } diff --git a/src/daft-core/src/count_mode.rs b/src/daft-core/src/count_mode.rs index 3cecb11473..fd7343a564 100644 --- a/src/daft-core/src/count_mode.rs +++ b/src/daft-core/src/count_mode.rs @@ -42,9 +42,7 @@ impl_bincode_py_state_serialization!(CountMode); impl CountMode { pub fn iterator() -> std::slice::Iter<'static, CountMode> { - use CountMode::*; - - static COUNT_MODES: [CountMode; 3] = [All, Valid, Null]; + static COUNT_MODES: [CountMode; 3] = [CountMode::All, CountMode::Valid, CountMode::Null]; COUNT_MODES.iter() } } @@ -53,12 +51,10 @@ impl FromStr for CountMode { type Err = DaftError; fn from_str(count_mode: &str) -> DaftResult { - use CountMode::*; - match count_mode { - "all" => Ok(All), - "valid" => Ok(Valid), - "null" => Ok(Null), + "all" => Ok(CountMode::All), + "valid" => Ok(CountMode::Valid), + "null" => Ok(CountMode::Null), _ => Err(DaftError::TypeError(format!( "Count mode {} is not supported; only the following modes are supported: {:?}", count_mode, diff --git a/src/daft-core/src/datatypes/agg_ops.rs b/src/daft-core/src/datatypes/agg_ops.rs index 38e5de1e30..a6420b039b 100644 --- 
a/src/daft-core/src/datatypes/agg_ops.rs +++ b/src/daft-core/src/datatypes/agg_ops.rs @@ -6,14 +6,15 @@ use super::DataType; /// Get the data type that the sum of a column of the given data type should be casted to. pub fn try_sum_supertype(dtype: &DataType) -> DaftResult { - use DataType::*; match dtype { - Int8 | Int16 | Int32 | Int64 => Ok(Int64), - UInt8 | UInt16 | UInt32 | UInt64 => Ok(UInt64), - Float32 => Ok(Float32), - Float64 => Ok(Float64), + DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => Ok(DataType::Int64), + DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => { + Ok(DataType::UInt64) + } + DataType::Float32 => Ok(DataType::Float32), + DataType::Float64 => Ok(DataType::Float64), // 38 is the maximum precision for Decimal128, while 19 is the max increase based on 2^64 rows - Decimal128(a, b) => Ok(Decimal128(min(38, *a + 19), *b)), + DataType::Decimal128(a, b) => Ok(DataType::Decimal128(min(38, *a + 19), *b)), other => Err(DaftError::TypeError(format!( "Invalid argument to sum supertype: {}", other @@ -23,9 +24,8 @@ pub fn try_sum_supertype(dtype: &DataType) -> DaftResult { /// Get the data type that the mean of a column of the given data type should be casted to. pub fn try_mean_supertype(dtype: &DataType) -> DaftResult { - use DataType::*; if dtype.is_numeric() { - Ok(Float64) + Ok(DataType::Float64) } else { Err(DaftError::TypeError(format!( "Invalid argument to mean supertype: {}", diff --git a/src/daft-core/src/datatypes/infer_datatype.rs b/src/daft-core/src/datatypes/infer_datatype.rs index 730e8c6aeb..51679381a5 100644 --- a/src/daft-core/src/datatypes/infer_datatype.rs +++ b/src/daft-core/src/datatypes/infer_datatype.rs @@ -33,13 +33,14 @@ impl<'a> AsRef for InferDataType<'a> { impl<'a> InferDataType<'a> { pub fn logical_op(&self, other: &Self) -> DaftResult { // Whether a logical op (and, or, xor) is supported between the two types. 
- use DataType::*; let left = self.0; let other = other.0; match (left, other) { #[cfg(feature = "python")] - (Python, _) | (_, Python) => Ok(Boolean), - (Boolean, Boolean) | (Boolean, Null) | (Null, Boolean) => Ok(Boolean), + (DataType::Python, _) | (_, DataType::Python) => Ok(DataType::Boolean), + (DataType::Boolean, DataType::Boolean) + | (DataType::Boolean, DataType::Null) + | (DataType::Null, DataType::Boolean) => Ok(DataType::Boolean), (s, o) if s.is_integer() && o.is_integer() => { let dtype = try_numeric_supertype(s, o)?; if dtype.is_floating() { @@ -72,32 +73,32 @@ impl<'a> InferDataType<'a> { let left = &self.0; let other = &other.0; - let evaluator = || { - use DataType::*; - match (left, other) { - (s, o) if s == o => Ok((Boolean, None, s.to_physical())), - (Utf8, o) | (o, Utf8) if o.is_numeric() => Err(DaftError::TypeError(format!( - "Cannot perform comparison on Utf8 and numeric type.\ntypes: {}, {}", - left, other - ))), - (s, o) if s.is_physical() && o.is_physical() => { - Ok((Boolean, None, try_physical_supertype(s, o)?)) - } - (Timestamp(..), Timestamp(..)) => { - let intermediate_type = try_get_supertype(left, other)?; - let pt = intermediate_type.to_physical(); - Ok((Boolean, Some(intermediate_type), pt)) - } - (Timestamp(..), Date) | (Date, Timestamp(..)) => { - let intermediate_type = Date; - let pt = intermediate_type.to_physical(); - Ok((Boolean, Some(intermediate_type), pt)) - } - _ => Err(DaftError::TypeError(format!( - "Cannot perform comparison on types: {}, {}", + let evaluator = || match (left, other) { + (s, o) if s == o => Ok((DataType::Boolean, None, s.to_physical())), + (DataType::Utf8, o) | (o, DataType::Utf8) if o.is_numeric() => { + Err(DaftError::TypeError(format!( + "Cannot perform comparison on Utf8 and numeric type.\ntypes: {}, {}", left, other - ))), + ))) + } + (s, o) if s.is_physical() && o.is_physical() => { + Ok((DataType::Boolean, None, try_physical_supertype(s, o)?)) + } + (DataType::Timestamp(..), 
DataType::Timestamp(..)) => { + let intermediate_type = try_get_supertype(left, other)?; + let pt = intermediate_type.to_physical(); + Ok((DataType::Boolean, Some(intermediate_type), pt)) } + (DataType::Timestamp(..), DataType::Date) + | (DataType::Date, DataType::Timestamp(..)) => { + let intermediate_type = DataType::Date; + let pt = intermediate_type.to_physical(); + Ok((DataType::Boolean, Some(intermediate_type), pt)) + } + _ => Err(DaftError::TypeError(format!( + "Cannot perform comparison on types: {}, {}", + left, other + ))), }; evaluator().map_err(|err| { @@ -120,30 +121,28 @@ impl<'a> Add for InferDataType<'a> { type Output = DaftResult; fn add(self, other: Self) -> Self::Output { - use DataType::*; - try_numeric_supertype(self.0, other.0).or(try_fixed_shape_numeric_datatype(self.0, other.0, |l, r| {InferDataType::from(l) + InferDataType::from(r)})).or( match (self.0, other.0) { #[cfg(feature = "python")] - (Python, _) | (_, Python) => Ok(Python), - (Timestamp(t_unit, tz), Duration(d_unit)) - | (Duration(d_unit), Timestamp(t_unit, tz)) - if t_unit == d_unit => Ok(Timestamp(*t_unit, tz.clone())), - (ts @ Timestamp(..), du @ Duration(..)) - | (du @ Duration(..), ts @ Timestamp(..)) => Err(DaftError::TypeError( + (DataType::Python, _) | (_, DataType::Python) => Ok(DataType::Python), + (DataType::Timestamp(t_unit, tz), DataType::Duration(d_unit)) + | (DataType::Duration(d_unit), DataType::Timestamp(t_unit, tz)) + if t_unit == d_unit => Ok(DataType::Timestamp(*t_unit, tz.clone())), + (ts @ DataType::Timestamp(..), du @ DataType::Duration(..)) + | (du @ DataType::Duration(..), ts @ DataType::Timestamp(..)) => Err(DaftError::TypeError( format!("Cannot add due to differing precision: {}, {}. 
Please explicitly cast to the precision you wish to add in.", ts, du) )), - (Date, Duration(..)) | (Duration(..), Date) => Ok(Date), - (Duration(d_unit_self), Duration(d_unit_other)) if d_unit_self == d_unit_other => { - Ok(Duration(*d_unit_self)) + (DataType::Date, DataType::Duration(..)) | (DataType::Duration(..), DataType::Date) => Ok(DataType::Date), + (DataType::Duration(d_unit_self), DataType::Duration(d_unit_other)) if d_unit_self == d_unit_other => { + Ok(DataType::Duration(*d_unit_self)) }, - (du_self @ &Duration(..), du_other @ &Duration(..)) => Err(DaftError::TypeError( + (du_self @ &DataType::Duration(..), du_other @ &DataType::Duration(..)) => Err(DaftError::TypeError( format!("Cannot add due to differing precision: {}, {}. Please explicitly cast to the precision you wish to add in.", du_self, du_other) )), - (Null, other) | (other, Null) => { + (DataType::Null, other) | (other, DataType::Null) => { match other { // Condition is for backwards compatibility. TODO: remove - Binary | FixedSizeBinary(..) | Date => Err(DaftError::TypeError( + DataType::Binary | DataType::FixedSizeBinary(..) | DataType::Date => Err(DaftError::TypeError( format!("Cannot add types: {}, {}", self, other) )), other if other.is_physical() => Ok(other.clone()), @@ -152,19 +151,19 @@ impl<'a> Add for InferDataType<'a> { )), } } - (Utf8, other) | (other, Utf8) => { + (DataType::Utf8, other) | (other, DataType::Utf8) => { match other { - // Date condition is for backwards compatibility. TODO: remove - Binary | FixedSizeBinary(..) | Date => Err(DaftError::TypeError( + // DataType::Date condition is for backwards compatibility. TODO: remove + DataType::Binary | DataType::FixedSizeBinary(..) 
| DataType::Date => Err(DaftError::TypeError( format!("Cannot add types: {}, {}", self, other) )), - other if other.is_physical() => Ok(Utf8), + other if other.is_physical() => Ok(DataType::Utf8), _ => Err(DaftError::TypeError( format!("Cannot add types: {}, {}", self, other) )), } } - (Boolean, other) | (other, Boolean) + (DataType::Boolean, other) | (other, DataType::Boolean) if other.is_numeric() => Ok(other.clone()), _ => Err(DaftError::TypeError( format!("Cannot add types: {}, {}", self, other) @@ -178,27 +177,26 @@ impl<'a> Sub for InferDataType<'a> { type Output = DaftResult; fn sub(self, other: Self) -> Self::Output { - use DataType::*; try_numeric_supertype(self.0, other.0).or(try_fixed_shape_numeric_datatype(self.0, other.0, |l, r| {InferDataType::from(l) - InferDataType::from(r)})).or( match (self.0, other.0) { #[cfg(feature = "python")] - (Python, _) | (_, Python) => Ok(Python), - (Timestamp(t_unit, tz), Duration(d_unit)) - if t_unit == d_unit => Ok(Timestamp(*t_unit, tz.clone())), - (ts @ Timestamp(..), du @ Duration(..)) => Err(DaftError::TypeError( + (DataType::Python, _) | (_, DataType::Python) => Ok(DataType::Python), + (DataType::Timestamp(t_unit, tz), DataType::Duration(d_unit)) + if t_unit == d_unit => Ok(DataType::Timestamp(*t_unit, tz.clone())), + (ts @ DataType::Timestamp(..), du @ DataType::Duration(..)) => Err(DaftError::TypeError( format!("Cannot subtract due to differing precision: {}, {}. 
Please explicitly cast to the precision you wish to add in.", ts, du) )), - (Timestamp(t_unit_self, tz_self), Timestamp(t_unit_other, tz_other)) - if t_unit_self == t_unit_other && tz_self == tz_other => Ok(Duration(*t_unit_self)), - (ts @ Timestamp(..), ts_other @ Timestamp(..)) => Err(DaftError::TypeError( + (DataType::Timestamp(t_unit_self, tz_self), DataType::Timestamp(t_unit_other, tz_other)) + if t_unit_self == t_unit_other && tz_self == tz_other => Ok(DataType::Duration(*t_unit_self)), + (ts @ DataType::Timestamp(..), ts_other @ DataType::Timestamp(..)) => Err(DaftError::TypeError( format!("Cannot subtract due to differing precision or timezone: {}, {}. Please explicitly cast to the precision or timezone you wish to add in.", ts, ts_other) )), - (Date, Duration(..)) => Ok(Date), - (Date, Date) => Ok(Duration(crate::datatypes::TimeUnit::Seconds)), - (Duration(d_unit_self), Duration(d_unit_other)) if d_unit_self == d_unit_other => { - Ok(Duration(*d_unit_self)) + (DataType::Date, DataType::Duration(..)) => Ok(DataType::Date), + (DataType::Date, DataType::Date) => Ok(DataType::Duration(crate::datatypes::TimeUnit::Seconds)), + (DataType::Duration(d_unit_self), DataType::Duration(d_unit_other)) if d_unit_self == d_unit_other => { + Ok(DataType::Duration(*d_unit_self)) }, - (du_self @ &Duration(..), du_other @ &Duration(..)) => Err(DaftError::TypeError( + (du_self @ &DataType::Duration(..), du_other @ &DataType::Duration(..)) => Err(DaftError::TypeError( format!("Cannot subtract due to differing precision: {}, {}. 
Please explicitly cast to the precision you wish to add in.", du_self, du_other) )), _ => Err(DaftError::TypeError( @@ -213,11 +211,10 @@ impl<'a> Div for InferDataType<'a> { type Output = DaftResult; fn div(self, other: Self) -> Self::Output { - use DataType::*; match (&self.0, &other.0) { #[cfg(feature = "python")] - (Python, _) | (_, Python) => Ok(Python), - (s, o) if s.is_numeric() && o.is_numeric() => Ok(Float64), + (DataType::Python, _) | (_, DataType::Python) => Ok(DataType::Python), + (s, o) if s.is_numeric() && o.is_numeric() => Ok(DataType::Float64), _ => Err(DaftError::TypeError(format!( "Cannot divide types: {}, {}", self, other @@ -233,14 +230,13 @@ impl<'a> Mul for InferDataType<'a> { type Output = DaftResult; fn mul(self, other: Self) -> Self::Output { - use DataType::*; try_numeric_supertype(self.0, other.0) .or(try_fixed_shape_numeric_datatype(self.0, other.0, |l, r| { InferDataType::from(l) * InferDataType::from(r) })) .or(match (self.0, other.0) { #[cfg(feature = "python")] - (Python, _) | (_, Python) => Ok(Python), + (DataType::Python, _) | (_, DataType::Python) => Ok(DataType::Python), _ => Err(DaftError::TypeError(format!( "Cannot multiply types: {}, {}", self, other @@ -253,14 +249,13 @@ impl<'a> Rem for InferDataType<'a> { type Output = DaftResult; fn rem(self, other: Self) -> Self::Output { - use DataType::*; try_numeric_supertype(self.0, other.0) .or(try_fixed_shape_numeric_datatype(self.0, other.0, |l, r| { InferDataType::from(l) % InferDataType::from(r) })) .or(match (self.0, other.0) { #[cfg(feature = "python")] - (Python, _) | (_, Python) => Ok(Python), + (DataType::Python, _) | (_, DataType::Python) => Ok(DataType::Python), _ => Err(DaftError::TypeError(format!( "Cannot multiply types: {}, {}", self, other @@ -301,14 +296,20 @@ pub fn try_physical_supertype(l: &DataType, r: &DataType) -> DaftResult Ok(other.clone()), - (Boolean, other) | (other, Boolean) if other.is_numeric() => Ok(other.clone()), + (DataType::Null, other) | (other, 
DataType::Null) if other.is_physical() => { + Ok(other.clone()) + } + (DataType::Boolean, other) | (other, DataType::Boolean) if other.is_numeric() => { + Ok(other.clone()) + } #[cfg(feature = "python")] - (Python, _) | (_, Python) => Ok(Python), - (Utf8, o) | (o, Utf8) if o.is_physical() && !matches!(o, Binary | FixedSizeBinary(..)) => { - Ok(Utf8) + (DataType::Python, _) | (_, DataType::Python) => Ok(DataType::Python), + (DataType::Utf8, o) | (o, DataType::Utf8) + if o.is_physical() + && !matches!(o, DataType::Binary | DataType::FixedSizeBinary(..)) => + { + Ok(DataType::Utf8) } _ => Err(DaftError::TypeError(format!( "Invalid arguments to try_physical_supertype: {}, {}", @@ -323,73 +324,71 @@ pub fn try_numeric_supertype(l: &DataType, r: &DataType) -> DaftResult // for the purpose of performing numeric operations. fn inner(l: &DataType, r: &DataType) -> Option { - use DataType::*; - match (l, r) { - (Int8, Int8) => Some(Int8), - (Int8, Int16) => Some(Int16), - (Int8, Int32) => Some(Int32), - (Int8, Int64) => Some(Int64), - (Int8, UInt8) => Some(Int16), - (Int8, UInt16) => Some(Int32), - (Int8, UInt32) => Some(Int64), - (Int8, UInt64) => Some(Float64), // Follow numpy - (Int8, Float32) => Some(Float32), - (Int8, Float64) => Some(Float64), - - (Int16, Int16) => Some(Int16), - (Int16, Int32) => Some(Int32), - (Int16, Int64) => Some(Int64), - (Int16, UInt8) => Some(Int16), - (Int16, UInt16) => Some(Int32), - (Int16, UInt32) => Some(Int64), - (Int16, UInt64) => Some(Float64), // Follow numpy - (Int16, Float32) => Some(Float32), - (Int16, Float64) => Some(Float64), - - (Int32, Int32) => Some(Int32), - (Int32, Int64) => Some(Int64), - (Int32, UInt8) => Some(Int32), - (Int32, UInt16) => Some(Int32), - (Int32, UInt32) => Some(Int64), - (Int32, UInt64) => Some(Float64), // Follow numpy - (Int32, Float32) => Some(Float64), // Follow numpy - (Int32, Float64) => Some(Float64), - - (Int64, Int64) => Some(Int64), - (Int64, UInt8) => Some(Int64), - (Int64, UInt16) => 
Some(Int64), - (Int64, UInt32) => Some(Int64), - (Int64, UInt64) => Some(Float64), // Follow numpy - (Int64, Float32) => Some(Float64), // Follow numpy - (Int64, Float64) => Some(Float64), - - (UInt8, UInt8) => Some(UInt8), - (UInt8, UInt16) => Some(UInt16), - (UInt8, UInt32) => Some(UInt32), - (UInt8, UInt64) => Some(UInt64), - (UInt8, Float32) => Some(Float32), - (UInt8, Float64) => Some(Float64), - - (UInt16, UInt16) => Some(UInt16), - (UInt16, UInt32) => Some(UInt32), - (UInt16, UInt64) => Some(UInt64), - (UInt16, Float32) => Some(Float32), - (UInt16, Float64) => Some(Float64), - - (UInt32, UInt32) => Some(UInt32), - (UInt32, UInt64) => Some(UInt64), - (UInt32, Float32) => Some(Float64), - (UInt32, Float64) => Some(Float64), - - (UInt64, UInt64) => Some(UInt64), - (UInt64, Float32) => Some(Float64), - (UInt64, Float64) => Some(Float64), - - (Float32, Float32) => Some(Float32), - (Float32, Float64) => Some(Float64), - - (Float64, Float64) => Some(Float64), + (DataType::Int8, DataType::Int8) => Some(DataType::Int8), + (DataType::Int8, DataType::Int16) => Some(DataType::Int16), + (DataType::Int8, DataType::Int32) => Some(DataType::Int32), + (DataType::Int8, DataType::Int64) => Some(DataType::Int64), + (DataType::Int8, DataType::UInt8) => Some(DataType::Int16), + (DataType::Int8, DataType::UInt16) => Some(DataType::Int32), + (DataType::Int8, DataType::UInt32) => Some(DataType::Int64), + (DataType::Int8, DataType::UInt64) => Some(DataType::Float64), // Follow numpy + (DataType::Int8, DataType::Float32) => Some(DataType::Float32), + (DataType::Int8, DataType::Float64) => Some(DataType::Float64), + + (DataType::Int16, DataType::Int16) => Some(DataType::Int16), + (DataType::Int16, DataType::Int32) => Some(DataType::Int32), + (DataType::Int16, DataType::Int64) => Some(DataType::Int64), + (DataType::Int16, DataType::UInt8) => Some(DataType::Int16), + (DataType::Int16, DataType::UInt16) => Some(DataType::Int32), + (DataType::Int16, DataType::UInt32) => 
Some(DataType::Int64), + (DataType::Int16, DataType::UInt64) => Some(DataType::Float64), // Follow numpy + (DataType::Int16, DataType::Float32) => Some(DataType::Float32), + (DataType::Int16, DataType::Float64) => Some(DataType::Float64), + + (DataType::Int32, DataType::Int32) => Some(DataType::Int32), + (DataType::Int32, DataType::Int64) => Some(DataType::Int64), + (DataType::Int32, DataType::UInt8) => Some(DataType::Int32), + (DataType::Int32, DataType::UInt16) => Some(DataType::Int32), + (DataType::Int32, DataType::UInt32) => Some(DataType::Int64), + (DataType::Int32, DataType::UInt64) => Some(DataType::Float64), // Follow numpy + (DataType::Int32, DataType::Float32) => Some(DataType::Float64), // Follow numpy + (DataType::Int32, DataType::Float64) => Some(DataType::Float64), + + (DataType::Int64, DataType::Int64) => Some(DataType::Int64), + (DataType::Int64, DataType::UInt8) => Some(DataType::Int64), + (DataType::Int64, DataType::UInt16) => Some(DataType::Int64), + (DataType::Int64, DataType::UInt32) => Some(DataType::Int64), + (DataType::Int64, DataType::UInt64) => Some(DataType::Float64), // Follow numpy + (DataType::Int64, DataType::Float32) => Some(DataType::Float64), // Follow numpy + (DataType::Int64, DataType::Float64) => Some(DataType::Float64), + + (DataType::UInt8, DataType::UInt8) => Some(DataType::UInt8), + (DataType::UInt8, DataType::UInt16) => Some(DataType::UInt16), + (DataType::UInt8, DataType::UInt32) => Some(DataType::UInt32), + (DataType::UInt8, DataType::UInt64) => Some(DataType::UInt64), + (DataType::UInt8, DataType::Float32) => Some(DataType::Float32), + (DataType::UInt8, DataType::Float64) => Some(DataType::Float64), + + (DataType::UInt16, DataType::UInt16) => Some(DataType::UInt16), + (DataType::UInt16, DataType::UInt32) => Some(DataType::UInt32), + (DataType::UInt16, DataType::UInt64) => Some(DataType::UInt64), + (DataType::UInt16, DataType::Float32) => Some(DataType::Float32), + (DataType::UInt16, DataType::Float64) => 
Some(DataType::Float64), + + (DataType::UInt32, DataType::UInt32) => Some(DataType::UInt32), + (DataType::UInt32, DataType::UInt64) => Some(DataType::UInt64), + (DataType::UInt32, DataType::Float32) => Some(DataType::Float64), + (DataType::UInt32, DataType::Float64) => Some(DataType::Float64), + + (DataType::UInt64, DataType::UInt64) => Some(DataType::UInt64), + (DataType::UInt64, DataType::Float32) => Some(DataType::Float64), + (DataType::UInt64, DataType::Float64) => Some(DataType::Float64), + + (DataType::Float32, DataType::Float32) => Some(DataType::Float32), + (DataType::Float32, DataType::Float64) => Some(DataType::Float64), + + (DataType::Float64, DataType::Float64) => Some(DataType::Float64), _ => None, } @@ -411,10 +410,11 @@ pub fn try_fixed_shape_numeric_datatype( where F: Fn(&DataType, &DataType) -> DaftResult, { - use DataType::*; - match (l, r) { - (FixedShapeTensor(ldtype, lshape), FixedShapeTensor(rdtype, rshape)) => { + ( + DataType::FixedShapeTensor(ldtype, lshape), + DataType::FixedShapeTensor(rdtype, rshape), + ) => { if lshape != rshape { Err(DaftError::TypeError(format!( "Cannot add types: {}, {} due to shape mismatch", @@ -423,7 +423,10 @@ where } else if let Ok(result_type) = inner_f(ldtype.as_ref(), rdtype.as_ref()) && result_type.is_numeric() { - Ok(FixedShapeTensor(Box::new(result_type), lshape.clone())) + Ok(DataType::FixedShapeTensor( + Box::new(result_type), + lshape.clone(), + )) } else { Err(DaftError::TypeError(format!( "Cannot add types: {}, {}", @@ -431,14 +434,14 @@ where ))) } } - (FixedSizeList(ldtype, lsize), FixedSizeList(rdtype, rsize)) => { + (DataType::FixedSizeList(ldtype, lsize), DataType::FixedSizeList(rdtype, rsize)) => { if lsize != rsize { Err(DaftError::TypeError(format!( "Cannot add types: {}, {} due to shape mismatch", l, r ))) } else if let Ok(result_type) = inner_f(ldtype.as_ref(), rdtype.as_ref()) { - Ok(FixedSizeList(Box::new(result_type), *lsize)) + Ok(DataType::FixedSizeList(Box::new(result_type), *lsize)) } 
else { Err(DaftError::TypeError(format!( "Cannot add types: {}, {}", @@ -446,7 +449,7 @@ where ))) } } - (Embedding(ldtype, lsize), Embedding(rdtype, rsize)) => { + (DataType::Embedding(ldtype, lsize), DataType::Embedding(rdtype, rsize)) => { if lsize != rsize { Err(DaftError::TypeError(format!( "Cannot add types: {}, {} due to shape mismatch", @@ -455,7 +458,7 @@ where } else if let Ok(result_type) = inner_f(ldtype.as_ref(), rdtype.as_ref()) && result_type.is_numeric() { - Ok(Embedding(Box::new(result_type), *lsize)) + Ok(DataType::Embedding(Box::new(result_type), *lsize)) } else { Err(DaftError::TypeError(format!( "Cannot add types: {}, {}", diff --git a/src/daft-core/src/join.rs b/src/daft-core/src/join.rs index d10d6d59f6..365d141757 100644 --- a/src/daft-core/src/join.rs +++ b/src/daft-core/src/join.rs @@ -40,9 +40,14 @@ impl_bincode_py_state_serialization!(JoinType); impl JoinType { pub fn iterator() -> std::slice::Iter<'static, JoinType> { - use JoinType::*; - - static JOIN_TYPES: [JoinType; 6] = [Inner, Left, Right, Outer, Anti, Semi]; + static JOIN_TYPES: [JoinType; 6] = [ + JoinType::Inner, + JoinType::Left, + JoinType::Right, + JoinType::Outer, + JoinType::Anti, + JoinType::Semi, + ]; JOIN_TYPES.iter() } } @@ -51,15 +56,13 @@ impl FromStr for JoinType { type Err = DaftError; fn from_str(join_type: &str) -> DaftResult { - use JoinType::*; - match join_type { - "inner" => Ok(Inner), - "left" => Ok(Left), - "right" => Ok(Right), - "outer" => Ok(Outer), - "anti" => Ok(Anti), - "semi" => Ok(Semi), + "inner" => Ok(JoinType::Inner), + "left" => Ok(JoinType::Left), + "right" => Ok(JoinType::Right), + "outer" => Ok(JoinType::Outer), + "anti" => Ok(JoinType::Anti), + "semi" => Ok(JoinType::Semi), _ => Err(DaftError::TypeError(format!( "Join type {} is not supported; only the following types are supported: {:?}", join_type, @@ -97,9 +100,11 @@ impl_bincode_py_state_serialization!(JoinStrategy); impl JoinStrategy { pub fn iterator() -> std::slice::Iter<'static, 
JoinStrategy> { - use JoinStrategy::*; - - static JOIN_STRATEGIES: [JoinStrategy; 3] = [Hash, SortMerge, Broadcast]; + static JOIN_STRATEGIES: [JoinStrategy; 3] = [ + JoinStrategy::Hash, + JoinStrategy::SortMerge, + JoinStrategy::Broadcast, + ]; JOIN_STRATEGIES.iter() } } @@ -108,12 +113,10 @@ impl FromStr for JoinStrategy { type Err = DaftError; fn from_str(join_strategy: &str) -> DaftResult { - use JoinStrategy::*; - match join_strategy { - "hash" => Ok(Hash), - "sort_merge" => Ok(SortMerge), - "broadcast" => Ok(Broadcast), + "hash" => Ok(JoinStrategy::Hash), + "sort_merge" => Ok(JoinStrategy::SortMerge), + "broadcast" => Ok(JoinStrategy::Broadcast), _ => Err(DaftError::TypeError(format!( "Join strategy {} is not supported; only the following strategies are supported: {:?}", join_strategy, diff --git a/src/daft-core/src/kernels/hashing.rs b/src/daft-core/src/kernels/hashing.rs index c7151c10ce..293f4bc803 100644 --- a/src/daft-core/src/kernels/hashing.rs +++ b/src/daft-core/src/kernels/hashing.rs @@ -175,18 +175,21 @@ pub fn hash(array: &dyn Array, seed: Option<&PrimitiveArray>) -> Result hash_null(array.as_any().downcast_ref().unwrap(), seed), - Boolean => hash_boolean(array.as_any().downcast_ref().unwrap(), seed), - Primitive(primitive) => with_match_hashing_primitive_type!(primitive, |$T| { + PhysicalType::Null => hash_null(array.as_any().downcast_ref().unwrap(), seed), + PhysicalType::Boolean => hash_boolean(array.as_any().downcast_ref().unwrap(), seed), + PhysicalType::Primitive(primitive) => with_match_hashing_primitive_type!(primitive, |$T| { hash_primitive::<$T>(array.as_any().downcast_ref().unwrap(), seed) }), - Binary => hash_binary::(array.as_any().downcast_ref().unwrap(), seed), - LargeBinary => hash_binary::(array.as_any().downcast_ref().unwrap(), seed), - FixedSizeBinary => hash_fixed_size_binary(array.as_any().downcast_ref().unwrap(), seed), - Utf8 => hash_utf8::(array.as_any().downcast_ref().unwrap(), seed), - LargeUtf8 => 
hash_utf8::(array.as_any().downcast_ref().unwrap(), seed), + PhysicalType::Binary => hash_binary::(array.as_any().downcast_ref().unwrap(), seed), + PhysicalType::LargeBinary => { + hash_binary::(array.as_any().downcast_ref().unwrap(), seed) + } + PhysicalType::FixedSizeBinary => { + hash_fixed_size_binary(array.as_any().downcast_ref().unwrap(), seed) + } + PhysicalType::Utf8 => hash_utf8::(array.as_any().downcast_ref().unwrap(), seed), + PhysicalType::LargeUtf8 => hash_utf8::(array.as_any().downcast_ref().unwrap(), seed), t => { return Err(Error::NotYetImplemented(format!( "Hash not implemented for type {t:?}" diff --git a/src/daft-core/src/kernels/search_sorted.rs b/src/daft-core/src/kernels/search_sorted.rs index 8290adc796..f8b0a0f946 100644 --- a/src/daft-core/src/kernels/search_sorted.rs +++ b/src/daft-core/src/kernels/search_sorted.rs @@ -571,7 +571,6 @@ pub fn search_sorted( keys: &dyn Array, input_reversed: bool, ) -> Result> { - use PhysicalType::*; if sorted_array.data_type() != keys.data_type() { let error_string = format!( "sorted array data type does not match keys data type: {:?} vs {:?}", @@ -582,35 +581,37 @@ pub fn search_sorted( } Ok(match sorted_array.data_type().to_physical_type() { // Boolean => hash_boolean(array.as_any().downcast_ref().unwrap()), - Primitive(primitive) => with_match_searching_primitive_type!(primitive, |$T| { - search_sorted_primitive_array::<$T>(sorted_array.as_any().downcast_ref().unwrap(), keys.as_any().downcast_ref().unwrap(), input_reversed) - }), - Utf8 => search_sorted_utf_array::( + PhysicalType::Primitive(primitive) => { + with_match_searching_primitive_type!(primitive, |$T| { + search_sorted_primitive_array::<$T>(sorted_array.as_any().downcast_ref().unwrap(), keys.as_any().downcast_ref().unwrap(), input_reversed) + }) + } + PhysicalType::Utf8 => search_sorted_utf_array::( sorted_array.as_any().downcast_ref().unwrap(), keys.as_any().downcast_ref().unwrap(), input_reversed, ), - LargeUtf8 => search_sorted_utf_array::( 
+ PhysicalType::LargeUtf8 => search_sorted_utf_array::( sorted_array.as_any().downcast_ref().unwrap(), keys.as_any().downcast_ref().unwrap(), input_reversed, ), - Binary => search_sorted_binary_array::( + PhysicalType::Binary => search_sorted_binary_array::( sorted_array.as_any().downcast_ref().unwrap(), keys.as_any().downcast_ref().unwrap(), input_reversed, ), - LargeBinary => search_sorted_binary_array::( + PhysicalType::LargeBinary => search_sorted_binary_array::( sorted_array.as_any().downcast_ref().unwrap(), keys.as_any().downcast_ref().unwrap(), input_reversed, ), - FixedSizeBinary => search_sorted_fixed_size_binary_array( + PhysicalType::FixedSizeBinary => search_sorted_fixed_size_binary_array( sorted_array.as_any().downcast_ref().unwrap(), keys.as_any().downcast_ref().unwrap(), input_reversed, ), - Boolean => search_sorted_boolean_array( + PhysicalType::Boolean => search_sorted_boolean_array( sorted_array.as_any().downcast_ref().unwrap(), keys.as_any().downcast_ref().unwrap(), input_reversed, diff --git a/src/daft-core/src/series/array_impl/binary_ops.rs b/src/daft-core/src/series/array_impl/binary_ops.rs index 2e3821df3a..71190ed70a 100644 --- a/src/daft-core/src/series/array_impl/binary_ops.rs +++ b/src/daft-core/src/series/array_impl/binary_ops.rs @@ -125,10 +125,9 @@ macro_rules! py_numeric_binary_op { let output_type = InferDataType::from($self.data_type()).$op(InferDataType::from($rhs.data_type()))?; let lhs = $self.into_series(); - use DataType::*; match &output_type { #[cfg(feature = "python")] - Python => Ok(py_binary_op!(lhs, $rhs, $pyop)), + DataType::Python => Ok(py_binary_op!(lhs, $rhs, $pyop)), output_type if output_type.is_numeric() => { with_match_numeric_daft_types!(output_type, |$T| { cast_downcast_op_into_series!( @@ -153,13 +152,16 @@ macro_rules! 
physical_logic_op { let output_type = InferDataType::from($self.data_type()) .logical_op(&InferDataType::from($rhs.data_type()))?; let lhs = $self.into_series(); - use DataType::*; match &output_type { #[cfg(feature = "python")] - Boolean => match (&lhs.data_type(), &$rhs.data_type()) { + DataType::Boolean => match (&lhs.data_type(), &$rhs.data_type()) { #[cfg(feature = "python")] - (Python, _) | (_, Python) => Ok(py_binary_op_bool!(lhs, $rhs, $pyop)), - _ => cast_downcast_op_into_series!(lhs, $rhs, &Boolean, BooleanArray, $op), + (DataType::Python, _) | (_, DataType::Python) => { + Ok(py_binary_op_bool!(lhs, $rhs, $pyop)) + } + _ => { + cast_downcast_op_into_series!(lhs, $rhs, &DataType::Boolean, BooleanArray, $op) + } }, output_type if output_type.is_integer() => { with_match_integer_daft_types!(output_type, |$T| { @@ -188,11 +190,10 @@ macro_rules! physical_compare_op { (lhs, $rhs.clone()) }; - use DataType::*; - if let Boolean = output_type { + if let DataType::Boolean = output_type { match comp_type { #[cfg(feature = "python")] - Python => py_binary_op_bool!(lhs, rhs, $pyop) + DataType::Python => py_binary_op_bool!(lhs, rhs, $pyop) .downcast::() .cloned(), _ => with_match_comparable_daft_types!(comp_type, |$T| { @@ -210,11 +211,12 @@ pub(crate) trait SeriesBinaryOps: SeriesLike { let output_type = InferDataType::from(self.data_type()).add(InferDataType::from(rhs.data_type()))?; let lhs = self.into_series(); - use DataType::*; match &output_type { #[cfg(feature = "python")] - Python => Ok(py_binary_op!(lhs, rhs, "add")), - Utf8 => cast_downcast_op_into_series!(lhs, rhs, &Utf8, Utf8Array, add), + DataType::Python => Ok(py_binary_op!(lhs, rhs, "add")), + DataType::Utf8 => { + cast_downcast_op_into_series!(lhs, rhs, &DataType::Utf8, Utf8Array, add) + } output_type if output_type.is_numeric() => { with_match_numeric_daft_types!(output_type, |$T| { cast_downcast_op_into_series!(lhs, rhs, output_type, <$T as DaftDataType>::ArrayType, add) @@ -236,11 +238,12 @@ 
pub(crate) trait SeriesBinaryOps: SeriesLike { let output_type = InferDataType::from(self.data_type()).div(InferDataType::from(rhs.data_type()))?; let lhs = self.into_series(); - use DataType::*; match &output_type { #[cfg(feature = "python")] - Python => Ok(py_binary_op!(lhs, rhs, "truediv")), - Float64 => cast_downcast_op_into_series!(lhs, rhs, &Float64, Float64Array, div), + DataType::Python => Ok(py_binary_op!(lhs, rhs, "truediv")), + DataType::Float64 => { + cast_downcast_op_into_series!(lhs, rhs, &DataType::Float64, Float64Array, div) + } output_type if output_type.is_fixed_size_numeric() => { fixed_sized_numeric_binary_op!(&lhs, rhs, output_type, div) } @@ -305,11 +308,10 @@ impl SeriesBinaryOps for ArrayWrapper {} impl SeriesBinaryOps for ArrayWrapper {} impl SeriesBinaryOps for ArrayWrapper { fn add(&self, rhs: &Series) -> DaftResult { - use DataType::*; let output_type = (InferDataType::from(self.data_type()) + InferDataType::from(rhs.data_type()))?; match rhs.data_type() { - Duration(..) => { + DataType::Duration(..) => { let days = rhs.duration()?.cast_to_days()?; let physical_result = self.0.physical.add(&days)?; physical_result.cast(&output_type) @@ -318,15 +320,14 @@ impl SeriesBinaryOps for ArrayWrapper { } } fn sub(&self, rhs: &Series) -> DaftResult { - use DataType::*; let output_type = (InferDataType::from(self.data_type()) - InferDataType::from(rhs.data_type()))?; match rhs.data_type() { - Date => { + DataType::Date => { let physical_result = self.0.physical.sub(&rhs.date()?.physical)?; physical_result.cast(&output_type) } - Duration(..) => { + DataType::Duration(..) 
=> { let days = rhs.duration()?.cast_to_days()?; let physical_result = self.0.physical.sub(&days)?; physical_result.cast(&output_type) @@ -338,20 +339,19 @@ impl SeriesBinaryOps for ArrayWrapper { impl SeriesBinaryOps for ArrayWrapper {} impl SeriesBinaryOps for ArrayWrapper { fn add(&self, rhs: &Series) -> DaftResult { - use DataType::*; let output_type = (InferDataType::from(self.data_type()) + InferDataType::from(rhs.data_type()))?; let lhs = self.0.clone().into_series(); match rhs.data_type() { - Timestamp(..) => { + DataType::Timestamp(..) => { let physical_result = self.0.physical.add(&rhs.timestamp()?.physical)?; physical_result.cast(&output_type) } - Duration(..) => { + DataType::Duration(..) => { let physical_result = self.0.physical.add(&rhs.duration()?.physical)?; physical_result.cast(&output_type) } - Date => { + DataType::Date => { let days = self.0.cast_to_days()?; let physical_result = days.add(&rhs.date()?.physical)?; physical_result.cast(&output_type) @@ -361,11 +361,10 @@ impl SeriesBinaryOps for ArrayWrapper { } fn sub(&self, rhs: &Series) -> DaftResult { - use DataType::*; let output_type = (InferDataType::from(self.data_type()) - InferDataType::from(rhs.data_type()))?; match rhs.data_type() { - Duration(..) => { + DataType::Duration(..) => { let physical_result = self.0.physical.sub(&rhs.duration()?.physical)?; physical_result.cast(&output_type) } @@ -376,11 +375,10 @@ impl SeriesBinaryOps for ArrayWrapper { impl SeriesBinaryOps for ArrayWrapper { fn add(&self, rhs: &Series) -> DaftResult { - use DataType::*; let output_type = (InferDataType::from(self.data_type()) + InferDataType::from(rhs.data_type()))?; match rhs.data_type() { - Duration(..) => { + DataType::Duration(..) 
=> { let physical_result = self.0.physical.add(&rhs.duration()?.physical)?; physical_result.cast(&output_type) } @@ -388,15 +386,14 @@ impl SeriesBinaryOps for ArrayWrapper { } } fn sub(&self, rhs: &Series) -> DaftResult { - use DataType::*; let output_type = (InferDataType::from(self.data_type()) - InferDataType::from(rhs.data_type()))?; match rhs.data_type() { - Duration(..) => { + DataType::Duration(..) => { let physical_result = self.0.physical.sub(&rhs.duration()?.physical)?; physical_result.cast(&output_type) } - Timestamp(..) => { + DataType::Timestamp(..) => { let physical_result = self.0.physical.sub(&rhs.timestamp()?.physical)?; physical_result.cast(&output_type) } diff --git a/src/daft-core/src/series/ops/abs.rs b/src/daft-core/src/series/ops/abs.rs index 99d8c70742..da3d092ca7 100644 --- a/src/daft-core/src/series/ops/abs.rs +++ b/src/daft-core/src/series/ops/abs.rs @@ -1,20 +1,21 @@ use crate::datatypes::DataType; +use crate::series::array_impl::IntoSeries; use crate::series::Series; use common_error::DaftError; use common_error::DaftResult; + impl Series { pub fn abs(&self) -> DaftResult { - use crate::series::array_impl::IntoSeries; - - use DataType::*; match self.data_type() { - Int8 => Ok(self.i8().unwrap().abs()?.into_series()), - Int16 => Ok(self.i16().unwrap().abs()?.into_series()), - Int32 => Ok(self.i32().unwrap().abs()?.into_series()), - Int64 => Ok(self.i64().unwrap().abs()?.into_series()), - UInt8 | UInt16 | UInt32 | UInt64 => Ok(self.clone()), - Float32 => Ok(self.f32().unwrap().abs()?.into_series()), - Float64 => Ok(self.f64().unwrap().abs()?.into_series()), + DataType::Int8 => Ok(self.i8().unwrap().abs()?.into_series()), + DataType::Int16 => Ok(self.i16().unwrap().abs()?.into_series()), + DataType::Int32 => Ok(self.i32().unwrap().abs()?.into_series()), + DataType::Int64 => Ok(self.i64().unwrap().abs()?.into_series()), + DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => { + Ok(self.clone()) + } + 
DataType::Float32 => Ok(self.f32().unwrap().abs()?.into_series()), + DataType::Float64 => Ok(self.f64().unwrap().abs()?.into_series()), dt => Err(DaftError::TypeError(format!( "abs not implemented for {}", dt diff --git a/src/daft-core/src/series/ops/ceil.rs b/src/daft-core/src/series/ops/ceil.rs index 3a01f045a6..17cb55beae 100644 --- a/src/daft-core/src/series/ops/ceil.rs +++ b/src/daft-core/src/series/ops/ceil.rs @@ -1,16 +1,22 @@ use crate::datatypes::DataType; +use crate::series::array_impl::IntoSeries; use crate::series::Series; use common_error::DaftError; use common_error::DaftResult; + impl Series { pub fn ceil(&self) -> DaftResult { - use crate::series::array_impl::IntoSeries; - - use DataType::*; match self.data_type() { - Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64 => Ok(self.clone()), - Float32 => Ok(self.f32().unwrap().ceil()?.into_series()), - Float64 => Ok(self.f64().unwrap().ceil()?.into_series()), + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 => Ok(self.clone()), + DataType::Float32 => Ok(self.f32().unwrap().ceil()?.into_series()), + DataType::Float64 => Ok(self.f64().unwrap().ceil()?.into_series()), dt => Err(DaftError::TypeError(format!( "ceil not implemented for {}", dt diff --git a/src/daft-core/src/series/ops/exp.rs b/src/daft-core/src/series/ops/exp.rs index 9959a4cf36..f52c44ff2d 100644 --- a/src/daft-core/src/series/ops/exp.rs +++ b/src/daft-core/src/series/ops/exp.rs @@ -2,17 +2,15 @@ use common_error::DaftError; use common_error::DaftResult; use crate::datatypes::DataType; +use crate::series::array_impl::IntoSeries; use crate::series::Series; impl Series { pub fn exp(&self) -> DaftResult { - use crate::series::array_impl::IntoSeries; - - use DataType::*; match self.data_type() { - Float32 => Ok(self.f32().unwrap().exp()?.into_series()), - Float64 => Ok(self.f64().unwrap().exp()?.into_series()), - dt if 
dt.is_integer() => self.cast(&Float64).unwrap().exp(), + DataType::Float32 => Ok(self.f32().unwrap().exp()?.into_series()), + DataType::Float64 => Ok(self.f64().unwrap().exp()?.into_series()), + dt if dt.is_integer() => self.cast(&DataType::Float64).unwrap().exp(), dt => Err(DaftError::TypeError(format!( "exp not implemented for {}", dt diff --git a/src/daft-core/src/series/ops/floor.rs b/src/daft-core/src/series/ops/floor.rs index 59256fef5f..59ae87fcf1 100644 --- a/src/daft-core/src/series/ops/floor.rs +++ b/src/daft-core/src/series/ops/floor.rs @@ -1,16 +1,22 @@ use crate::datatypes::DataType; +use crate::series::array_impl::IntoSeries; use crate::series::Series; use common_error::DaftError; use common_error::DaftResult; impl Series { pub fn floor(&self) -> DaftResult { - use crate::series::array_impl::IntoSeries; - use DataType::*; match self.data_type() { - Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64 => Ok(self.clone()), - Float32 => Ok(self.f32().unwrap().floor()?.into_series()), - Float64 => Ok(self.f64().unwrap().floor()?.into_series()), + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 => Ok(self.clone()), + DataType::Float32 => Ok(self.f32().unwrap().floor()?.into_series()), + DataType::Float64 => Ok(self.f64().unwrap().floor()?.into_series()), dt => Err(DaftError::TypeError(format!( "floor not implemented for {}", dt diff --git a/src/daft-core/src/series/ops/list.rs b/src/daft-core/src/series/ops/list.rs index 57f74b70ff..5590e74ef8 100644 --- a/src/daft-core/src/series/ops/list.rs +++ b/src/daft-core/src/series/ops/list.rs @@ -8,10 +8,9 @@ use common_error::DaftResult; impl Series { pub fn explode(&self) -> DaftResult { - use DataType::*; match self.data_type() { - List(_) => self.list()?.explode(), - FixedSizeList(..) 
=> self.fixed_size_list()?.explode(), + DataType::List(_) => self.list()?.explode(), + DataType::FixedSizeList(..) => self.fixed_size_list()?.explode(), dt => Err(DaftError::TypeError(format!( "explode not implemented for {}", dt @@ -20,13 +19,13 @@ impl Series { } pub fn list_count(&self, mode: CountMode) -> DaftResult { - use DataType::*; - match self.data_type() { - List(_) => self.list()?.count(mode), - FixedSizeList(..) => self.fixed_size_list()?.count(mode), - Embedding(..) | FixedShapeImage(..) => self.as_physical()?.list_count(mode), - Image(..) => { + DataType::List(_) => self.list()?.count(mode), + DataType::FixedSizeList(..) => self.fixed_size_list()?.count(mode), + DataType::Embedding(..) | DataType::FixedShapeImage(..) => { + self.as_physical()?.list_count(mode) + } + DataType::Image(..) => { let struct_array = self.as_physical()?; let data_array = struct_array.struct_()?.children[0].list().unwrap(); let offsets = data_array.offsets(); diff --git a/src/daft-core/src/series/ops/log.rs b/src/daft-core/src/series/ops/log.rs index 47c06b9dca..cb22175df8 100644 --- a/src/daft-core/src/series/ops/log.rs +++ b/src/daft-core/src/series/ops/log.rs @@ -1,18 +1,25 @@ use crate::datatypes::DataType; +use crate::series::array_impl::IntoSeries; use crate::series::Series; use common_error::DaftError; use common_error::DaftResult; + impl Series { pub fn log2(&self) -> DaftResult { - use crate::series::array_impl::IntoSeries; - use DataType::*; match self.data_type() { - Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64 => { + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 => { let s = self.cast(&DataType::Float64)?; Ok(s.f64()?.log2()?.into_series()) } - Float32 => Ok(self.f32()?.log2()?.into_series()), - Float64 => Ok(self.f64()?.log2()?.into_series()), + DataType::Float32 => Ok(self.f32()?.log2()?.into_series()), + DataType::Float64 => 
Ok(self.f64()?.log2()?.into_series()), dt => Err(DaftError::TypeError(format!( "log2 not implemented for {}", dt @@ -21,15 +28,20 @@ impl Series { } pub fn log10(&self) -> DaftResult { - use crate::series::array_impl::IntoSeries; - use DataType::*; match self.data_type() { - Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64 => { + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 => { let s = self.cast(&DataType::Float64)?; Ok(s.f64()?.log10()?.into_series()) } - Float32 => Ok(self.f32()?.log10()?.into_series()), - Float64 => Ok(self.f64()?.log10()?.into_series()), + DataType::Float32 => Ok(self.f32()?.log10()?.into_series()), + DataType::Float64 => Ok(self.f64()?.log10()?.into_series()), dt => Err(DaftError::TypeError(format!( "log10 not implemented for {}", dt @@ -38,15 +50,20 @@ impl Series { } pub fn log(&self, base: f64) -> DaftResult { - use crate::series::array_impl::IntoSeries; - use DataType::*; match self.data_type() { - Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64 => { + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 => { let s = self.cast(&DataType::Float64)?; Ok(s.f64()?.log(base)?.into_series()) } - Float32 => Ok(self.f32()?.log(base as f32)?.into_series()), - Float64 => Ok(self.f64()?.log(base)?.into_series()), + DataType::Float32 => Ok(self.f32()?.log(base as f32)?.into_series()), + DataType::Float64 => Ok(self.f64()?.log(base)?.into_series()), dt => Err(DaftError::TypeError(format!( "log not implemented for {}", dt @@ -56,14 +73,20 @@ impl Series { pub fn ln(&self) -> DaftResult { use crate::series::array_impl::IntoSeries; - use DataType::*; match self.data_type() { - Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64 => { + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | 
DataType::Int64 + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 => { let s = self.cast(&DataType::Float64)?; Ok(s.f64()?.ln()?.into_series()) } - Float32 => Ok(self.f32()?.ln()?.into_series()), - Float64 => Ok(self.f64()?.ln()?.into_series()), + DataType::Float32 => Ok(self.f32()?.ln()?.into_series()), + DataType::Float64 => Ok(self.f64()?.ln()?.into_series()), dt => Err(DaftError::TypeError(format!( "ln not implemented for {}", dt diff --git a/src/daft-core/src/series/ops/round.rs b/src/daft-core/src/series/ops/round.rs index a924bfb172..a784e29638 100644 --- a/src/daft-core/src/series/ops/round.rs +++ b/src/daft-core/src/series/ops/round.rs @@ -1,16 +1,22 @@ use crate::datatypes::DataType; +use crate::series::array_impl::IntoSeries; use crate::series::Series; use common_error::DaftError; use common_error::DaftResult; impl Series { pub fn round(&self, decimal: i32) -> DaftResult { - use crate::series::array_impl::IntoSeries; - use DataType::*; match self.data_type() { - Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64 => Ok(self.clone()), - Float32 => Ok(self.f32().unwrap().round(decimal)?.into_series()), - Float64 => Ok(self.f64().unwrap().round(decimal)?.into_series()), + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 => Ok(self.clone()), + DataType::Float32 => Ok(self.f32().unwrap().round(decimal)?.into_series()), + DataType::Float64 => Ok(self.f64().unwrap().round(decimal)?.into_series()), dt => Err(DaftError::TypeError(format!( "round not implemented for {}", dt diff --git a/src/daft-core/src/series/ops/sign.rs b/src/daft-core/src/series/ops/sign.rs index e12de0726b..6c93183a4a 100644 --- a/src/daft-core/src/series/ops/sign.rs +++ b/src/daft-core/src/series/ops/sign.rs @@ -1,23 +1,22 @@ use crate::datatypes::DataType; +use crate::series::array_impl::IntoSeries; use crate::series::Series; use 
common_error::DaftError; use common_error::DaftResult; impl Series { pub fn sign(&self) -> DaftResult { - use crate::series::array_impl::IntoSeries; - use DataType::*; match self.data_type() { - UInt8 => Ok(self.u8().unwrap().sign_unsigned()?.into_series()), - UInt16 => Ok(self.u16().unwrap().sign_unsigned()?.into_series()), - UInt32 => Ok(self.u32().unwrap().sign_unsigned()?.into_series()), - UInt64 => Ok(self.u64().unwrap().sign_unsigned()?.into_series()), - Int8 => Ok(self.i8().unwrap().sign()?.into_series()), - Int16 => Ok(self.i16().unwrap().sign()?.into_series()), - Int32 => Ok(self.i32().unwrap().sign()?.into_series()), - Int64 => Ok(self.i64().unwrap().sign()?.into_series()), - Float32 => Ok(self.f32().unwrap().sign()?.into_series()), - Float64 => Ok(self.f64().unwrap().sign()?.into_series()), + DataType::UInt8 => Ok(self.u8().unwrap().sign_unsigned()?.into_series()), + DataType::UInt16 => Ok(self.u16().unwrap().sign_unsigned()?.into_series()), + DataType::UInt32 => Ok(self.u32().unwrap().sign_unsigned()?.into_series()), + DataType::UInt64 => Ok(self.u64().unwrap().sign_unsigned()?.into_series()), + DataType::Int8 => Ok(self.i8().unwrap().sign()?.into_series()), + DataType::Int16 => Ok(self.i16().unwrap().sign()?.into_series()), + DataType::Int32 => Ok(self.i32().unwrap().sign()?.into_series()), + DataType::Int64 => Ok(self.i64().unwrap().sign()?.into_series()), + DataType::Float32 => Ok(self.f32().unwrap().sign()?.into_series()), + DataType::Float64 => Ok(self.f64().unwrap().sign()?.into_series()), dt => Err(DaftError::TypeError(format!( "sign not implemented for {}", dt diff --git a/src/daft-core/src/series/serdes.rs b/src/daft-core/src/series/serdes.rs index 7339dc7a15..9a550e11c2 100644 --- a/src/daft-core/src/series/serdes.rs +++ b/src/daft-core/src/series/serdes.rs @@ -3,17 +3,16 @@ use std::{borrow::Cow, sync::Arc}; use arrow2::offset::OffsetsBuffer; use serde::{de::Visitor, Deserializer}; +use crate::datatypes::*; + use crate::{ array::{ 
ops::{as_arrow::AsArrow, full::FullNull}, ListArray, StructArray, }, - datatypes::{ - logical::{ - DateArray, Decimal128Array, DurationArray, EmbeddingArray, FixedShapeImageArray, - FixedShapeTensorArray, ImageArray, MapArray, TensorArray, TimeArray, TimestampArray, - }, - DataType, + datatypes::logical::{ + DateArray, Decimal128Array, DurationArray, EmbeddingArray, FixedShapeImageArray, + FixedShapeTensorArray, ImageArray, MapArray, TensorArray, TimeArray, TimestampArray, }, series::{IntoSeries, Series}, with_match_daft_types, @@ -71,92 +70,90 @@ impl<'d> serde::Deserialize<'d> for Series { return Err(serde::de::Error::missing_field("values")); } let field = field.ok_or_else(|| serde::de::Error::missing_field("name"))?; - use crate::datatypes::*; - use DataType::*; match &field.dtype { - Null => Ok(NullArray::full_null( + DataType::Null => Ok(NullArray::full_null( &field.name, &field.dtype, map.next_value::()?, ) .into_series()), - Boolean => Ok(BooleanArray::from(( + DataType::Boolean => Ok(BooleanArray::from(( field.name.as_str(), map.next_value::>>()?.as_slice(), )) .into_series()), - Int8 => Ok(Int8Array::from_iter( + DataType::Int8 => Ok(Int8Array::from_iter( field.name.as_str(), map.next_value::>>()?.into_iter(), ) .into_series()), - Int16 => Ok(Int16Array::from_iter( + DataType::Int16 => Ok(Int16Array::from_iter( field.name.as_str(), map.next_value::>>()?.into_iter(), ) .into_series()), - Int32 => Ok(Int32Array::from_iter( + DataType::Int32 => Ok(Int32Array::from_iter( field.name.as_str(), map.next_value::>>()?.into_iter(), ) .into_series()), - Int64 => Ok(Int64Array::from_iter( + DataType::Int64 => Ok(Int64Array::from_iter( field.name.as_str(), map.next_value::>>()?.into_iter(), ) .into_series()), - Int128 => Ok(Int128Array::from_iter( + DataType::Int128 => Ok(Int128Array::from_iter( field.name.as_str(), map.next_value::>>()?.into_iter(), ) .into_series()), - UInt8 => Ok(UInt8Array::from_iter( + DataType::UInt8 => Ok(UInt8Array::from_iter( 
field.name.as_str(), map.next_value::>>()?.into_iter(), ) .into_series()), - UInt16 => Ok(UInt16Array::from_iter( + DataType::UInt16 => Ok(UInt16Array::from_iter( field.name.as_str(), map.next_value::>>()?.into_iter(), ) .into_series()), - UInt32 => Ok(UInt32Array::from_iter( + DataType::UInt32 => Ok(UInt32Array::from_iter( field.name.as_str(), map.next_value::>>()?.into_iter(), ) .into_series()), - UInt64 => Ok(UInt64Array::from_iter( + DataType::UInt64 => Ok(UInt64Array::from_iter( field.name.as_str(), map.next_value::>>()?.into_iter(), ) .into_series()), - Float32 => Ok(Float32Array::from_iter( + DataType::Float32 => Ok(Float32Array::from_iter( field.name.as_str(), map.next_value::>>()?.into_iter(), ) .into_series()), - Float64 => Ok(Float64Array::from_iter( + DataType::Float64 => Ok(Float64Array::from_iter( field.name.as_str(), map.next_value::>>()?.into_iter(), ) .into_series()), - Utf8 => Ok(Utf8Array::from_iter( + DataType::Utf8 => Ok(Utf8Array::from_iter( field.name.as_str(), map.next_value::>>>()?.into_iter(), ) .into_series()), - Binary => Ok(BinaryArray::from_iter( + DataType::Binary => Ok(BinaryArray::from_iter( field.name.as_str(), map.next_value::>>>()?.into_iter(), ) .into_series()), - FixedSizeBinary(size) => Ok(FixedSizeBinaryArray::from_iter( + DataType::FixedSizeBinary(size) => Ok(FixedSizeBinaryArray::from_iter( field.name.as_str(), map.next_value::>>>()?.into_iter(), *size, ) .into_series()), - Extension(..) => { + DataType::Extension(..) => { let physical = map.next_value::()?; let physical = physical.to_arrow(); let ext_array = physical.to_type(field.dtype.to_arrow().unwrap()); @@ -164,7 +161,7 @@ impl<'d> serde::Deserialize<'d> for Series { .unwrap() .into_series()) } - Map(..) => { + DataType::Map(..) => { let physical = map.next_value::()?; Ok(MapArray::new( Arc::new(field), @@ -172,7 +169,7 @@ impl<'d> serde::Deserialize<'d> for Series { ) .into_series()) } - Struct(..) => { + DataType::Struct(..) 
=> { let mut all_series = map.next_value::>>()?; let validity = all_series .pop() @@ -185,7 +182,7 @@ impl<'d> serde::Deserialize<'d> for Series { let validity = validity.map(|v| v.bool().unwrap().as_bitmap().clone()); Ok(StructArray::new(Arc::new(field), children, validity).into_series()) } - List(..) => { + DataType::List(..) => { let mut all_series = map.next_value::>>()?; let validity = all_series .pop() @@ -206,7 +203,7 @@ impl<'d> serde::Deserialize<'d> for Series { .unwrap(); Ok(ListArray::new(field, flat_child, offsets, validity).into_series()) } - FixedSizeList(..) => { + DataType::FixedSizeList(..) => { let mut all_series = map.next_value::>>()?; let validity = all_series .pop() @@ -219,7 +216,7 @@ impl<'d> serde::Deserialize<'d> for Series { let validity = validity.map(|v| v.bool().unwrap().as_bitmap().clone()); Ok(FixedSizeListArray::new(field, flat_child, validity).into_series()) } - Decimal128(..) => { + DataType::Decimal128(..) => { type PType = <::PhysicalType as DaftDataType>::ArrayType; let physical = map.next_value::()?; Ok(Decimal128Array::new( @@ -228,7 +225,7 @@ impl<'d> serde::Deserialize<'d> for Series { ) .into_series()) } - Timestamp(..) => { + DataType::Timestamp(..) => { type PType = <::PhysicalType as DaftDataType>::ArrayType; let physical = map.next_value::()?; Ok(TimestampArray::new( @@ -237,7 +234,7 @@ impl<'d> serde::Deserialize<'d> for Series { ) .into_series()) } - Date => { + DataType::Date => { type PType = <::PhysicalType as DaftDataType>::ArrayType; let physical = map.next_value::()?; Ok( @@ -245,7 +242,7 @@ impl<'d> serde::Deserialize<'d> for Series { .into_series(), ) } - Time(..) => { + DataType::Time(..) => { type PType = <::PhysicalType as DaftDataType>::ArrayType; let physical = map.next_value::()?; Ok( @@ -253,7 +250,7 @@ impl<'d> serde::Deserialize<'d> for Series { .into_series(), ) } - Duration(..) => { + DataType::Duration(..) 
=> { type PType = <::PhysicalType as DaftDataType>::ArrayType; let physical = map.next_value::()?; Ok( @@ -264,7 +261,7 @@ impl<'d> serde::Deserialize<'d> for Series { .into_series(), ) } - Embedding(..) => { + DataType::Embedding(..) => { type PType = <::PhysicalType as DaftDataType>::ArrayType; let physical = map.next_value::()?; Ok(EmbeddingArray::new( @@ -273,7 +270,7 @@ impl<'d> serde::Deserialize<'d> for Series { ) .into_series()) } - Image(..) => { + DataType::Image(..) => { type PType = <::PhysicalType as DaftDataType>::ArrayType; let physical = map.next_value::()?; Ok( @@ -281,7 +278,7 @@ impl<'d> serde::Deserialize<'d> for Series { .into_series(), ) } - FixedShapeImage(..) => { + DataType::FixedShapeImage(..) => { type PType = <::PhysicalType as DaftDataType>::ArrayType; let physical = map.next_value::()?; Ok(FixedShapeImageArray::new( @@ -290,7 +287,7 @@ impl<'d> serde::Deserialize<'d> for Series { ) .into_series()) } - FixedShapeTensor(..) => { + DataType::FixedShapeTensor(..) => { type PType = <::PhysicalType as DaftDataType>::ArrayType; let physical = map.next_value::()?; Ok(FixedShapeTensorArray::new( @@ -299,7 +296,7 @@ impl<'d> serde::Deserialize<'d> for Series { ) .into_series()) } - Tensor(..) => { + DataType::Tensor(..) => { type PType = <::PhysicalType as DaftDataType>::ArrayType; let physical = map.next_value::()?; Ok( @@ -307,10 +304,11 @@ impl<'d> serde::Deserialize<'d> for Series { .into_series(), ) } - Python => { + #[cfg(feature = "python")] + DataType::Python => { panic!("python deserialization not implemented for rust Serde"); } - Unknown => { + DataType::Unknown => { panic!("Unable to deserialize Unknown DataType"); } } diff --git a/src/daft-core/src/utils/supertype.rs b/src/daft-core/src/utils/supertype.rs index fcf5f07010..3133be89ad 100644 --- a/src/daft-core/src/utils/supertype.rs +++ b/src/daft-core/src/utils/supertype.rs @@ -6,10 +6,9 @@ use common_error::DaftResult; // TODO: Deprecate this logic soon! 
fn get_time_units(tu_l: &TimeUnit, tu_r: &TimeUnit) -> TimeUnit { - use TimeUnit::*; match (tu_l, tu_r) { - (Nanoseconds, Microseconds) => Microseconds, - (_, Milliseconds) => Milliseconds, + (TimeUnit::Nanoseconds, TimeUnit::Microseconds) => TimeUnit::Microseconds, + (_, TimeUnit::Milliseconds) => TimeUnit::Milliseconds, _ => *tu_l, } } @@ -25,155 +24,153 @@ pub fn try_get_supertype(l: &DataType, r: &DataType) -> DaftResult { pub fn get_supertype(l: &DataType, r: &DataType) -> Option { fn inner(l: &DataType, r: &DataType) -> Option { - use DataType::*; - if l == r { return Some(l.clone()); } match (l, r) { #[cfg(feature = "python")] - // The supertype of anything and Python is Python. - (_, Python) => Some(Python), - - (Int8, Boolean) => Some(Int8), - (Int8, Int16) => Some(Int16), - (Int8, Int32) => Some(Int32), - (Int8, Int64) => Some(Int64), - (Int8, UInt8) => Some(Int16), - (Int8, UInt16) => Some(Int32), - (Int8, UInt32) => Some(Int64), - (Int8, UInt64) => Some(Float64), // Follow numpy - (Int8, Float32) => Some(Float32), - (Int8, Float64) => Some(Float64), - - (Int16, Boolean) => Some(Int16), - (Int16, Int8) => Some(Int16), - (Int16, Int32) => Some(Int32), - (Int16, Int64) => Some(Int64), - (Int16, UInt8) => Some(Int16), - (Int16, UInt16) => Some(Int32), - (Int16, UInt32) => Some(Int64), - (Int16, UInt64) => Some(Float64), // Follow numpy - (Int16, Float32) => Some(Float32), - (Int16, Float64) => Some(Float64), - - (Int32, Boolean) => Some(Int32), - (Int32, Int8) => Some(Int32), - (Int32, Int16) => Some(Int32), - (Int32, Int64) => Some(Int64), - (Int32, UInt8) => Some(Int32), - (Int32, UInt16) => Some(Int32), - (Int32, UInt32) => Some(Int64), - (Int32, UInt64) => Some(Float64), // Follow numpy - (Int32, Float32) => Some(Float64), // Follow numpy - (Int32, Float64) => Some(Float64), - - (Int64, Boolean) => Some(Int64), - (Int64, Int8) => Some(Int64), - (Int64, Int16) => Some(Int64), - (Int64, Int32) => Some(Int64), - (Int64, UInt8) => Some(Int64), - (Int64, 
UInt16) => Some(Int64), - (Int64, UInt32) => Some(Int64), - (Int64, UInt64) => Some(Float64), // Follow numpy - (Int64, Float32) => Some(Float64), // Follow numpy - (Int64, Float64) => Some(Float64), - - (UInt16, UInt8) => Some(UInt16), - (UInt16, UInt32) => Some(UInt32), - (UInt16, UInt64) => Some(UInt64), - - (UInt8, UInt32) => Some(UInt32), - (UInt8, UInt64) => Some(UInt64), - (UInt32, UInt64) => Some(UInt64), - - (Boolean, UInt8) => Some(UInt8), - (Boolean, UInt16) => Some(UInt16), - (Boolean, UInt32) => Some(UInt32), - (Boolean, UInt64) => Some(UInt64), - - (Float32, UInt8) => Some(Float32), - (Float32, UInt16) => Some(Float32), - (Float32, UInt32) => Some(Float64), - (Float32, UInt64) => Some(Float64), - - (Float64, UInt8) => Some(Float64), - (Float64, UInt16) => Some(Float64), - (Float64, UInt32) => Some(Float64), - (Float64, UInt64) => Some(Float64), - - (Float64, Float32) => Some(Float64), - - (Date, UInt8) => Some(Int64), - (Date, UInt16) => Some(Int64), - (Date, UInt32) => Some(Int64), - (Date, UInt64) => Some(Int64), - (Date, Int8) => Some(Int32), - (Date, Int16) => Some(Int32), - (Date, Int32) => Some(Int32), - (Date, Int64) => Some(Int64), - (Date, Float32) => Some(Float32), - (Date, Float64) => Some(Float64), - (Date, Timestamp(tu, tz)) => Some(Timestamp(*tu, tz.clone())), - - (Timestamp(_, _), UInt32) => Some(Int64), - (Timestamp(_, _), UInt64) => Some(Int64), - (Timestamp(_, _), Int32) => Some(Int64), - (Timestamp(_, _), Int64) => Some(Int64), - (Timestamp(_, _), Float32) => Some(Float64), - (Timestamp(_, _), Float64) => Some(Float64), - (Timestamp(tu, tz), Date) => Some(Timestamp(*tu, tz.clone())), - - (Duration(_), UInt32) => Some(Int64), - (Duration(_), UInt64) => Some(Int64), - (Duration(_), Int32) => Some(Int64), - (Duration(_), Int64) => Some(Int64), - (Duration(_), Float32) => Some(Float64), - (Duration(_), Float64) => Some(Float64), - - (Time(_), Int32) => Some(Int64), - (Time(_), Int64) => Some(Int64), - (Time(_), Float32) => 
Some(Float64), - (Time(_), Float64) => Some(Float64), - - (Duration(lu), Timestamp(ru, Some(tz))) | (Timestamp(lu, Some(tz)), Duration(ru)) => { + // The supertype of anything and DataType::Python is DataType::Python. + (_, DataType::Python) => Some(DataType::Python), + + (DataType::Int8, DataType::Boolean) => Some(DataType::Int8), + (DataType::Int8, DataType::Int16) => Some(DataType::Int16), + (DataType::Int8, DataType::Int32) => Some(DataType::Int32), + (DataType::Int8, DataType::Int64) => Some(DataType::Int64), + (DataType::Int8, DataType::UInt8) => Some(DataType::Int16), + (DataType::Int8, DataType::UInt16) => Some(DataType::Int32), + (DataType::Int8, DataType::UInt32) => Some(DataType::Int64), + (DataType::Int8, DataType::UInt64) => Some(DataType::Float64), // Follow numpy + (DataType::Int8, DataType::Float32) => Some(DataType::Float32), + (DataType::Int8, DataType::Float64) => Some(DataType::Float64), + + (DataType::Int16, DataType::Boolean) => Some(DataType::Int16), + (DataType::Int16, DataType::Int8) => Some(DataType::Int16), + (DataType::Int16, DataType::Int32) => Some(DataType::Int32), + (DataType::Int16, DataType::Int64) => Some(DataType::Int64), + (DataType::Int16, DataType::UInt8) => Some(DataType::Int16), + (DataType::Int16, DataType::UInt16) => Some(DataType::Int32), + (DataType::Int16, DataType::UInt32) => Some(DataType::Int64), + (DataType::Int16, DataType::UInt64) => Some(DataType::Float64), // Follow numpy + (DataType::Int16, DataType::Float32) => Some(DataType::Float32), + (DataType::Int16, DataType::Float64) => Some(DataType::Float64), + + (DataType::Int32, DataType::Boolean) => Some(DataType::Int32), + (DataType::Int32, DataType::Int8) => Some(DataType::Int32), + (DataType::Int32, DataType::Int16) => Some(DataType::Int32), + (DataType::Int32, DataType::Int64) => Some(DataType::Int64), + (DataType::Int32, DataType::UInt8) => Some(DataType::Int32), + (DataType::Int32, DataType::UInt16) => Some(DataType::Int32), + (DataType::Int32, 
DataType::UInt32) => Some(DataType::Int64), + (DataType::Int32, DataType::UInt64) => Some(DataType::Float64), // Follow numpy + (DataType::Int32, DataType::Float32) => Some(DataType::Float64), // Follow numpy + (DataType::Int32, DataType::Float64) => Some(DataType::Float64), + + (DataType::Int64, DataType::Boolean) => Some(DataType::Int64), + (DataType::Int64, DataType::Int8) => Some(DataType::Int64), + (DataType::Int64, DataType::Int16) => Some(DataType::Int64), + (DataType::Int64, DataType::Int32) => Some(DataType::Int64), + (DataType::Int64, DataType::UInt8) => Some(DataType::Int64), + (DataType::Int64, DataType::UInt16) => Some(DataType::Int64), + (DataType::Int64, DataType::UInt32) => Some(DataType::Int64), + (DataType::Int64, DataType::UInt64) => Some(DataType::Float64), // Follow numpy + (DataType::Int64, DataType::Float32) => Some(DataType::Float64), // Follow numpy + (DataType::Int64, DataType::Float64) => Some(DataType::Float64), + + (DataType::UInt16, DataType::UInt8) => Some(DataType::UInt16), + (DataType::UInt16, DataType::UInt32) => Some(DataType::UInt32), + (DataType::UInt16, DataType::UInt64) => Some(DataType::UInt64), + + (DataType::UInt8, DataType::UInt32) => Some(DataType::UInt32), + (DataType::UInt8, DataType::UInt64) => Some(DataType::UInt64), + (DataType::UInt32, DataType::UInt64) => Some(DataType::UInt64), + + (DataType::Boolean, DataType::UInt8) => Some(DataType::UInt8), + (DataType::Boolean, DataType::UInt16) => Some(DataType::UInt16), + (DataType::Boolean, DataType::UInt32) => Some(DataType::UInt32), + (DataType::Boolean, DataType::UInt64) => Some(DataType::UInt64), + + (DataType::Float32, DataType::UInt8) => Some(DataType::Float32), + (DataType::Float32, DataType::UInt16) => Some(DataType::Float32), + (DataType::Float32, DataType::UInt32) => Some(DataType::Float64), + (DataType::Float32, DataType::UInt64) => Some(DataType::Float64), + + (DataType::Float64, DataType::UInt8) => Some(DataType::Float64), + (DataType::Float64, 
DataType::UInt16) => Some(DataType::Float64), + (DataType::Float64, DataType::UInt32) => Some(DataType::Float64), + (DataType::Float64, DataType::UInt64) => Some(DataType::Float64), + + (DataType::Float64, DataType::Float32) => Some(DataType::Float64), + + (DataType::Date, DataType::UInt8) => Some(DataType::Int64), + (DataType::Date, DataType::UInt16) => Some(DataType::Int64), + (DataType::Date, DataType::UInt32) => Some(DataType::Int64), + (DataType::Date, DataType::UInt64) => Some(DataType::Int64), + (DataType::Date, DataType::Int8) => Some(DataType::Int32), + (DataType::Date, DataType::Int16) => Some(DataType::Int32), + (DataType::Date, DataType::Int32) => Some(DataType::Int32), + (DataType::Date, DataType::Int64) => Some(DataType::Int64), + (DataType::Date, DataType::Float32) => Some(DataType::Float32), + (DataType::Date, DataType::Float64) => Some(DataType::Float64), + (DataType::Date, DataType::Timestamp(tu, tz)) => Some(DataType::Timestamp(*tu, tz.clone())), + + (DataType::Timestamp(_, _), DataType::UInt32) => Some(DataType::Int64), + (DataType::Timestamp(_, _), DataType::UInt64) => Some(DataType::Int64), + (DataType::Timestamp(_, _), DataType::Int32) => Some(DataType::Int64), + (DataType::Timestamp(_, _), DataType::Int64) => Some(DataType::Int64), + (DataType::Timestamp(_, _), DataType::Float32) => Some(DataType::Float64), + (DataType::Timestamp(_, _), DataType::Float64) => Some(DataType::Float64), + (DataType::Timestamp(tu, tz), DataType::Date) => Some(DataType::Timestamp(*tu, tz.clone())), + + (DataType::Duration(_), DataType::UInt32) => Some(DataType::Int64), + (DataType::Duration(_), DataType::UInt64) => Some(DataType::Int64), + (DataType::Duration(_), DataType::Int32) => Some(DataType::Int64), + (DataType::Duration(_), DataType::Int64) => Some(DataType::Int64), + (DataType::Duration(_), DataType::Float32) => Some(DataType::Float64), + (DataType::Duration(_), DataType::Float64) => Some(DataType::Float64), + + (DataType::Time(_), DataType::Int32) => 
Some(DataType::Int64), + (DataType::Time(_), DataType::Int64) => Some(DataType::Int64), + (DataType::Time(_), DataType::Float32) => Some(DataType::Float64), + (DataType::Time(_), DataType::Float64) => Some(DataType::Float64), + + (DataType::Duration(lu), DataType::Timestamp(ru, Some(tz))) | (DataType::Timestamp(lu, Some(tz)), DataType::Duration(ru)) => { if tz.is_empty() { - Some(Timestamp(get_time_units(lu, ru), None)) + Some(DataType::Timestamp(get_time_units(lu, ru), None)) } else { - Some(Timestamp(get_time_units(lu, ru), Some(tz.clone()))) + Some(DataType::Timestamp(get_time_units(lu, ru), Some(tz.clone()))) } } - (Duration(lu), Timestamp(ru, None)) | (Timestamp(lu, None), Duration(ru)) => { - Some(Timestamp(get_time_units(lu, ru), None)) + (DataType::Duration(lu), DataType::Timestamp(ru, None)) | (DataType::Timestamp(lu, None), DataType::Duration(ru)) => { + Some(DataType::Timestamp(get_time_units(lu, ru), None)) } - (Duration(_), Date) | (Date, Duration(_)) => Some(Date), - (Duration(lu), Duration(ru)) => Some(Duration(get_time_units(lu, ru))), + (DataType::Duration(_), DataType::Date) | (DataType::Date, DataType::Duration(_)) => Some(DataType::Date), + (DataType::Duration(lu), DataType::Duration(ru)) => Some(DataType::Duration(get_time_units(lu, ru))), // Some() timezones that are non equal // we cast from more precision to higher precision as that always fits with occasional loss of precision - (Timestamp(tu_l, Some(tz_l)), Timestamp(tu_r, Some(tz_r))) + (DataType::Timestamp(tu_l, Some(tz_l)), DataType::Timestamp(tu_r, Some(tz_r))) if !tz_l.is_empty() && !tz_r.is_empty() && tz_l != tz_r => { let tu = get_time_units(tu_l, tu_r); - Some(Timestamp(tu, Some("UTC".to_string()))) + Some(DataType::Timestamp(tu, Some("UTC".to_string()))) } // None and Some("") timezones // we cast from more precision to higher precision as that always fits with occasional loss of precision - (Timestamp(tu_l, tz_l), Timestamp(tu_r, tz_r)) if + (DataType::Timestamp(tu_l, tz_l), 
DataType::Timestamp(tu_r, tz_r)) if // both are none tz_l.is_none() && tz_r.is_none() // both have the same time zone || (tz_l.is_some() && (tz_l == tz_r)) => { let tu = get_time_units(tu_l, tu_r); - Some(Timestamp(tu, tz_r.clone())) + Some(DataType::Timestamp(tu, tz_r.clone())) } //TODO(sammy): add time, struct related dtypes - (Boolean, Float32) => Some(Float32), - (Boolean, Float64) => Some(Float64), - (List(inner_left_dtype), List(inner_right_dtype)) => { + (DataType::Boolean, DataType::Float32) => Some(DataType::Float32), + (DataType::Boolean, DataType::Float64) => Some(DataType::Float64), + (DataType::List(inner_left_dtype), DataType::List(inner_right_dtype)) => { let inner_st = get_supertype(inner_left_dtype.as_ref(), inner_right_dtype.as_ref())?; Some(DataType::List(Box::new(inner_st))) } @@ -194,17 +191,14 @@ pub fn get_supertype(l: &DataType, r: &DataType) -> Option { // } // every known type can be casted to a string except binary - (dt, Utf8) if !matches!(&dt, &Binary | &FixedSizeBinary(_)) => Some(Utf8), - (dt, Null) => Some(dt.clone()), // Drop Null Type + (dt, DataType::Utf8) if !matches!(&dt, &DataType::Binary | &DataType::FixedSizeBinary(_)) => Some(DataType::Utf8), + (dt, DataType::Null) => Some(dt.clone()), // Drop DataType::Null Type _ => None, } } - match inner(l, r) { - Some(dt) => Some(dt), - None => inner(r, l), - } + inner(l, r).or_else(|| inner(r, l)) } #[cfg(test)]