Skip to content

Commit

Permalink
[CHORE] Remove enum imports daft core (#2819)
Browse files Browse the repository at this point in the history
# Overview
- removed all instances of `use MyEnum::*;` from `daft-core`
- using that pattern has a huge chance of hard-to-catch bugs being
created in the future
  • Loading branch information
raunakab authored Sep 9, 2024
1 parent 08ca9a4 commit 7bee225
Show file tree
Hide file tree
Showing 20 changed files with 606 additions and 571 deletions.
14 changes: 10 additions & 4 deletions src/daft-core/src/array/ops/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,13 @@ where
// Get the result of the Arrow Logical->Target cast.
let result_arrow_array = {
// First, get corresponding Arrow LogicalArray of source DataArray
use DataType::*;
let source_arrow_array = match source_dtype {
// Wrapped primitives
Decimal128(..) | Date | Timestamp(..) | Duration(..) | Time(..) => {
DataType::Decimal128(..)
| DataType::Date
| DataType::Timestamp(..)
| DataType::Duration(..)
| DataType::Time(..) => {
with_match_daft_logical_primitive_types!(source_dtype, |$T| {
use arrow2::array::Array;
to_cast
Expand Down Expand Up @@ -111,11 +114,14 @@ where
// If the target type is also Logical, get the Arrow Physical.
let result_arrow_physical_array = {
if dtype.is_logical() {
use DataType::*;
let target_physical_type = dtype.to_physical().to_arrow()?;
match dtype {
// Primitive wrapper types: change the arrow2 array's type field to primitive
Decimal128(..) | Date | Timestamp(..) | Duration(..) | Time(..) => {
DataType::Decimal128(..)
| DataType::Date
| DataType::Timestamp(..)
| DataType::Duration(..)
| DataType::Time(..) => {
with_match_daft_logical_primitive_types!(dtype, |$P| {
use arrow2::array::Array;
result_arrow_array
Expand Down
134 changes: 66 additions & 68 deletions src/daft-core/src/array/ops/image.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,19 +57,17 @@ macro_rules! with_method_on_image_buffer {
(
$key_type:expr, $method: ident
) => {{
use DaftImageBuffer::*;

match $key_type {
L(img) => img.$method(),
LA(img) => img.$method(),
RGB(img) => img.$method(),
RGBA(img) => img.$method(),
L16(img) => img.$method(),
LA16(img) => img.$method(),
RGB16(img) => img.$method(),
RGBA16(img) => img.$method(),
RGB32F(img) => img.$method(),
RGBA32F(img) => img.$method(),
DaftImageBuffer::L(img) => img.$method(),
DaftImageBuffer::LA(img) => img.$method(),
DaftImageBuffer::RGB(img) => img.$method(),
DaftImageBuffer::RGBA(img) => img.$method(),
DaftImageBuffer::L16(img) => img.$method(),
DaftImageBuffer::LA16(img) => img.$method(),
DaftImageBuffer::RGB16(img) => img.$method(),
DaftImageBuffer::RGBA16(img) => img.$method(),
DaftImageBuffer::RGB32F(img) => img.$method(),
DaftImageBuffer::RGBA32F(img) => img.$method(),
}
}};
}
Expand Down Expand Up @@ -148,19 +146,17 @@ impl From<Wrap<ImageFormat>> for image::ImageFormat {
impl From<Wrap<ImageMode>> for image::ColorType {
fn from(image_mode: Wrap<ImageMode>) -> image::ColorType {
use image::ColorType;
use ImageMode::*;

match image_mode.0 {
L => ColorType::L8,
LA => ColorType::La8,
RGB => ColorType::Rgb8,
RGBA => ColorType::Rgba8,
L16 => ColorType::L16,
LA16 => ColorType::La16,
RGB16 => ColorType::Rgb16,
RGBA16 => ColorType::Rgba16,
RGB32F => ColorType::Rgb32F,
RGBA32F => ColorType::Rgba32F,
ImageMode::L => ColorType::L8,
ImageMode::LA => ColorType::La8,
ImageMode::RGB => ColorType::Rgb8,
ImageMode::RGBA => ColorType::Rgba8,
ImageMode::L16 => ColorType::L16,
ImageMode::LA16 => ColorType::La16,
ImageMode::RGB16 => ColorType::Rgb16,
ImageMode::RGBA16 => ColorType::Rgba16,
ImageMode::RGB32F => ColorType::Rgb32F,
ImageMode::RGBA32F => ColorType::Rgba32F,
}
}
}
Expand All @@ -170,19 +166,17 @@ impl TryFrom<image::ColorType> for Wrap<ImageMode> {

fn try_from(color: image::ColorType) -> DaftResult<Self> {
use image::ColorType;
use ImageMode::*;

Ok(Wrap(match color {
ColorType::L8 => Ok(L),
ColorType::La8 => Ok(LA),
ColorType::Rgb8 => Ok(RGB),
ColorType::Rgba8 => Ok(RGBA),
ColorType::L16 => Ok(L16),
ColorType::La16 => Ok(LA16),
ColorType::Rgb16 => Ok(RGB16),
ColorType::Rgba16 => Ok(RGBA16),
ColorType::Rgb32F => Ok(RGB32F),
ColorType::Rgba32F => Ok(RGBA32F),
ColorType::L8 => Ok(ImageMode::L),
ColorType::La8 => Ok(ImageMode::LA),
ColorType::Rgb8 => Ok(ImageMode::RGB),
ColorType::Rgba8 => Ok(ImageMode::RGBA),
ColorType::L16 => Ok(ImageMode::L16),
ColorType::La16 => Ok(ImageMode::LA16),
ColorType::Rgb16 => Ok(ImageMode::RGB16),
ColorType::Rgba16 => Ok(ImageMode::RGBA16),
ColorType::Rgb32F => Ok(ImageMode::RGB32F),
ColorType::Rgba32F => Ok(ImageMode::RGBA32F),
_ => Err(DaftError::ValueError(format!(
"Color type {:?} is not supported.",
color
Expand All @@ -201,12 +195,11 @@ impl<'a> DaftImageBuffer<'a> {
}

pub fn as_u8_slice(&'a self) -> &'a [u8] {
use DaftImageBuffer::*;
match self {
L(img) => img.as_raw(),
LA(img) => img.as_raw(),
RGB(img) => img.as_raw(),
RGBA(img) => img.as_raw(),
DaftImageBuffer::L(img) => img.as_raw(),
DaftImageBuffer::LA(img) => img.as_raw(),
DaftImageBuffer::RGB(img) => img.as_raw(),
DaftImageBuffer::RGBA(img) => img.as_raw(),
_ => unimplemented!("unimplemented {self:?}"),
}
}
Expand All @@ -216,19 +209,17 @@ impl<'a> DaftImageBuffer<'a> {
}

pub fn mode(&self) -> ImageMode {
use DaftImageBuffer::*;

match self {
L(..) => ImageMode::L,
LA(..) => ImageMode::LA,
RGB(..) => ImageMode::RGB,
RGBA(..) => ImageMode::RGBA,
L16(..) => ImageMode::L16,
LA16(..) => ImageMode::LA16,
RGB16(..) => ImageMode::RGB16,
RGBA16(..) => ImageMode::RGBA16,
RGB32F(..) => ImageMode::RGB32F,
RGBA32F(..) => ImageMode::RGBA32F,
DaftImageBuffer::L(..) => ImageMode::L,
DaftImageBuffer::LA(..) => ImageMode::LA,
DaftImageBuffer::RGB(..) => ImageMode::RGB,
DaftImageBuffer::RGBA(..) => ImageMode::RGBA,
DaftImageBuffer::L16(..) => ImageMode::L16,
DaftImageBuffer::LA16(..) => ImageMode::LA16,
DaftImageBuffer::RGB16(..) => ImageMode::RGB16,
DaftImageBuffer::RGBA16(..) => ImageMode::RGBA16,
DaftImageBuffer::RGB32F(..) => ImageMode::RGB32F,
DaftImageBuffer::RGBA32F(..) => ImageMode::RGBA32F,
}
}

Expand Down Expand Up @@ -272,24 +263,23 @@ impl<'a> DaftImageBuffer<'a> {
}

pub fn resize(&self, w: u32, h: u32) -> Self {
use DaftImageBuffer::*;
match self {
L(imgbuf) => {
DaftImageBuffer::L(imgbuf) => {
let result =
image::imageops::resize(imgbuf, w, h, image::imageops::FilterType::Triangle);
DaftImageBuffer::L(image_buffer_vec_to_cow(result))
}
LA(imgbuf) => {
DaftImageBuffer::LA(imgbuf) => {
let result =
image::imageops::resize(imgbuf, w, h, image::imageops::FilterType::Triangle);
DaftImageBuffer::LA(image_buffer_vec_to_cow(result))
}
RGB(imgbuf) => {
DaftImageBuffer::RGB(imgbuf) => {
let result =
image::imageops::resize(imgbuf, w, h, image::imageops::FilterType::Triangle);
DaftImageBuffer::RGB(image_buffer_vec_to_cow(result))
}
RGBA(imgbuf) => {
DaftImageBuffer::RGBA(imgbuf) => {
let result =
image::imageops::resize(imgbuf, w, h, image::imageops::FilterType::Triangle);
DaftImageBuffer::RGBA(image_buffer_vec_to_cow(result))
Expand Down Expand Up @@ -638,11 +628,15 @@ impl ImageArray {
inputs: &[Option<DaftImageBuffer<'_>>],
image_mode: &Option<ImageMode>,
) -> DaftResult<Self> {
use DaftImageBuffer::*;
let is_all_u8 = inputs
.iter()
.filter_map(|b| b.as_ref())
.all(|b| matches!(b, L(..) | LA(..) | RGB(..) | RGBA(..)));
let is_all_u8 = inputs.iter().filter_map(|b| b.as_ref()).all(|b| {
matches!(
b,
DaftImageBuffer::L(..)
| DaftImageBuffer::LA(..)
| DaftImageBuffer::RGB(..)
| DaftImageBuffer::RGBA(..)
)
});
assert!(is_all_u8);

let mut data_ref = Vec::with_capacity(inputs.len());
Expand Down Expand Up @@ -775,11 +769,15 @@ impl FixedShapeImageArray {
height: u32,
width: u32,
) -> DaftResult<Self> {
use DaftImageBuffer::*;
let is_all_u8 = inputs
.iter()
.filter_map(|b| b.as_ref())
.all(|b| matches!(b, L(..) | LA(..) | RGB(..) | RGBA(..)));
let is_all_u8 = inputs.iter().filter_map(|b| b.as_ref()).all(|b| {
matches!(
b,
DaftImageBuffer::L(..)
| DaftImageBuffer::LA(..)
| DaftImageBuffer::RGB(..)
| DaftImageBuffer::RGBA(..)
)
});
assert!(is_all_u8);

let num_channels = image_mode.num_channels();
Expand Down
50 changes: 24 additions & 26 deletions src/daft-core/src/array/ops/trigonometry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,20 +26,19 @@ pub enum TrigonometricFunction {

impl TrigonometricFunction {
pub fn fn_name(&self) -> &'static str {
use TrigonometricFunction::*;
match self {
Sin => "sin",
Cos => "cos",
Tan => "tan",
Cot => "cot",
ArcSin => "arcsin",
ArcCos => "arccos",
ArcTan => "arctan",
Radians => "radians",
Degrees => "degrees",
ArcTanh => "arctanh",
ArcCosh => "arccosh",
ArcSinh => "arcsinh",
TrigonometricFunction::Sin => "sin",
TrigonometricFunction::Cos => "cos",
TrigonometricFunction::Tan => "tan",
TrigonometricFunction::Cot => "cot",
TrigonometricFunction::ArcSin => "arcsin",
TrigonometricFunction::ArcCos => "arccos",
TrigonometricFunction::ArcTan => "arctan",
TrigonometricFunction::Radians => "radians",
TrigonometricFunction::Degrees => "degrees",
TrigonometricFunction::ArcTanh => "arctanh",
TrigonometricFunction::ArcCosh => "arccosh",
TrigonometricFunction::ArcSinh => "arcsinh",
}
}
}
Expand All @@ -50,20 +49,19 @@ where
T::Native: Float,
{
pub fn trigonometry(&self, func: &TrigonometricFunction) -> DaftResult<Self> {
use TrigonometricFunction::*;
match func {
Sin => self.apply(|v| v.sin()),
Cos => self.apply(|v| v.cos()),
Tan => self.apply(|v| v.tan()),
Cot => self.apply(|v| v.tan().powi(-1)),
ArcSin => self.apply(|v| v.asin()),
ArcCos => self.apply(|v| v.acos()),
ArcTan => self.apply(|v| v.atan()),
Radians => self.apply(|v| v.to_radians()),
Degrees => self.apply(|v| v.to_degrees()),
ArcTanh => self.apply(|v| v.atanh()),
ArcCosh => self.apply(|v| v.acosh()),
ArcSinh => self.apply(|v| v.asinh()),
TrigonometricFunction::Sin => self.apply(|v| v.sin()),
TrigonometricFunction::Cos => self.apply(|v| v.cos()),
TrigonometricFunction::Tan => self.apply(|v| v.tan()),
TrigonometricFunction::Cot => self.apply(|v| v.tan().powi(-1)),
TrigonometricFunction::ArcSin => self.apply(|v| v.asin()),
TrigonometricFunction::ArcCos => self.apply(|v| v.acos()),
TrigonometricFunction::ArcTan => self.apply(|v| v.atan()),
TrigonometricFunction::Radians => self.apply(|v| v.to_radians()),
TrigonometricFunction::Degrees => self.apply(|v| v.to_degrees()),
TrigonometricFunction::ArcTanh => self.apply(|v| v.atanh()),
TrigonometricFunction::ArcCosh => self.apply(|v| v.acosh()),
TrigonometricFunction::ArcSinh => self.apply(|v| v.asinh()),
}
}
}
Expand Down
12 changes: 4 additions & 8 deletions src/daft-core/src/count_mode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,7 @@ impl_bincode_py_state_serialization!(CountMode);

impl CountMode {
pub fn iterator() -> std::slice::Iter<'static, CountMode> {
use CountMode::*;

static COUNT_MODES: [CountMode; 3] = [All, Valid, Null];
static COUNT_MODES: [CountMode; 3] = [CountMode::All, CountMode::Valid, CountMode::Null];
COUNT_MODES.iter()
}
}
Expand All @@ -53,12 +51,10 @@ impl FromStr for CountMode {
type Err = DaftError;

fn from_str(count_mode: &str) -> DaftResult<Self> {
use CountMode::*;

match count_mode {
"all" => Ok(All),
"valid" => Ok(Valid),
"null" => Ok(Null),
"all" => Ok(CountMode::All),
"valid" => Ok(CountMode::Valid),
"null" => Ok(CountMode::Null),
_ => Err(DaftError::TypeError(format!(
"Count mode {} is not supported; only the following modes are supported: {:?}",
count_mode,
Expand Down
16 changes: 8 additions & 8 deletions src/daft-core/src/datatypes/agg_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@ use super::DataType;

/// Get the data type that the sum of a column of the given data type should be casted to.
pub fn try_sum_supertype(dtype: &DataType) -> DaftResult<DataType> {
use DataType::*;
match dtype {
Int8 | Int16 | Int32 | Int64 => Ok(Int64),
UInt8 | UInt16 | UInt32 | UInt64 => Ok(UInt64),
Float32 => Ok(Float32),
Float64 => Ok(Float64),
DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => Ok(DataType::Int64),
DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => {
Ok(DataType::UInt64)
}
DataType::Float32 => Ok(DataType::Float32),
DataType::Float64 => Ok(DataType::Float64),
// 38 is the maximum precision for Decimal128, while 19 is the max increase based on 2^64 rows
Decimal128(a, b) => Ok(Decimal128(min(38, *a + 19), *b)),
DataType::Decimal128(a, b) => Ok(DataType::Decimal128(min(38, *a + 19), *b)),
other => Err(DaftError::TypeError(format!(
"Invalid argument to sum supertype: {}",
other
Expand All @@ -23,9 +24,8 @@ pub fn try_sum_supertype(dtype: &DataType) -> DaftResult<DataType> {

/// Get the data type that the mean of a column of the given data type should be casted to.
pub fn try_mean_supertype(dtype: &DataType) -> DaftResult<DataType> {
use DataType::*;
if dtype.is_numeric() {
Ok(Float64)
Ok(DataType::Float64)
} else {
Err(DaftError::TypeError(format!(
"Invalid argument to mean supertype: {}",
Expand Down
Loading

0 comments on commit 7bee225

Please sign in to comment.