From 818c076c6c1d5e59e1d61a18b1a4389720f1b1c0 Mon Sep 17 00:00:00 2001 From: Raunak Bhagat Date: Sun, 6 Oct 2024 11:32:58 -0700 Subject: [PATCH 01/28] Clean up code - move all test mods into their own separate files - if `blah.rs` had a submodule `tests`, then `blah/mod.rs` would contain the original code and `blah/tests.rs` would contain the tests - removed all enum imports - ran `cargo clippy --fix ...` --- .../src/{arithmetic.rs => arithmetic/mod.rs} | 23 +- src/daft-dsl/src/arithmetic/tests.rs | 16 + src/daft-dsl/src/{expr.rs => expr/mod.rs} | 405 +++++++----------- src/daft-dsl/src/expr/tests.rs | 83 ++++ src/daft-dsl/src/functions/map/mod.rs | 3 +- src/daft-dsl/src/functions/mod.rs | 18 +- .../src/functions/partitioning/mod.rs | 13 +- src/daft-dsl/src/functions/python/mod.rs | 8 +- src/daft-dsl/src/functions/python/udf.rs | 2 +- src/daft-dsl/src/functions/sketch/mod.rs | 3 +- src/daft-dsl/src/functions/struct_/mod.rs | 3 +- src/daft-dsl/src/functions/utf8/mod.rs | 57 ++- src/daft-dsl/src/{join.rs => join/mod.rs} | 34 +- src/daft-dsl/src/join/tests.rs | 27 ++ src/daft-dsl/src/lit.rs | 171 ++++---- .../{resolve_expr.rs => resolve_expr/mod.rs} | 148 +------ src/daft-dsl/src/resolve_expr/tests.rs | 141 ++++++ 17 files changed, 570 insertions(+), 585 deletions(-) rename src/daft-dsl/src/{arithmetic.rs => arithmetic/mod.rs} (57%) create mode 100644 src/daft-dsl/src/arithmetic/tests.rs rename src/daft-dsl/src/{expr.rs => expr/mod.rs} (78%) create mode 100644 src/daft-dsl/src/expr/tests.rs rename src/daft-dsl/src/{join.rs => join/mod.rs} (79%) create mode 100644 src/daft-dsl/src/join/tests.rs rename src/daft-dsl/src/{resolve_expr.rs => resolve_expr/mod.rs} (75%) create mode 100644 src/daft-dsl/src/resolve_expr/tests.rs diff --git a/src/daft-dsl/src/arithmetic.rs b/src/daft-dsl/src/arithmetic/mod.rs similarity index 57% rename from src/daft-dsl/src/arithmetic.rs rename to src/daft-dsl/src/arithmetic/mod.rs index 95faa64074..d4222fe64c 100644 --- a/src/daft-dsl/src/arithmetic.rs +++ b/src/daft-dsl/src/arithmetic/mod.rs @@ -1,3 +1,6 @@ +#[cfg(test)] +mod tests; + use crate::{Expr, ExprRef, Operator}; macro_rules! impl_expr_op { @@ -21,23 +24,3 @@ impl_expr_op!(sub, Minus); impl_expr_op!(mul, Multiply); impl_expr_op!(div, TrueDivide); impl_expr_op!(rem, Modulus); - -#[cfg(test)] -mod tests { - use common_error::{DaftError, DaftResult}; - - use crate::{col, Expr}; - - #[test] - fn check_add_expr_type() -> DaftResult<()> { - let a = col("a"); - let b = col("b"); - let c = a.add(b); - match c.as_ref() { - Expr::BinaryOp { .. } => Ok(()), - other => Err(DaftError::ValueError(format!( - "expected expression to be a binary op expression, got {other:?}" - ))), - } - } -} diff --git a/src/daft-dsl/src/arithmetic/tests.rs b/src/daft-dsl/src/arithmetic/tests.rs new file mode 100644 index 0000000000..19a7c23310 --- /dev/null +++ b/src/daft-dsl/src/arithmetic/tests.rs @@ -0,0 +1,16 @@ +use common_error::{DaftError, DaftResult}; + +use crate::{col, Expr}; + +#[test] +fn check_add_expr_type() -> DaftResult<()> { + let a = col("a"); + let b = col("b"); + let c = a.add(b); + match c.as_ref() { + Expr::BinaryOp { .. } => Ok(()), + other => Err(DaftError::ValueError(format!( + "expected expression to be a binary op expression, got {other:?}" + ))), + } +} diff --git a/src/daft-dsl/src/expr.rs b/src/daft-dsl/src/expr/mod.rs similarity index 78% rename from src/daft-dsl/src/expr.rs rename to src/daft-dsl/src/expr/mod.rs index 48249355fc..b99813fc9d 100644 --- a/src/daft-dsl/src/expr.rs +++ b/src/daft-dsl/src/expr/mod.rs @@ -1,3 +1,6 @@ +#[cfg(test)] +mod tests; + use std::{ io::{self, Write}, sync::Arc, @@ -159,36 +162,34 @@ pub fn binary_op(op: Operator, left: ExprRef, right: ExprRef) -> ExprRef { impl AggExpr { pub fn name(&self) -> &str { - use AggExpr::*; match self { - Count(expr, ..) - | Sum(expr) - | ApproxPercentile(ApproxPercentileParams { child: expr, .. }) - | ApproxCountDistinct(expr) - | ApproxSketch(expr, _) - | MergeSketch(expr, _) - | Mean(expr) - | Min(expr) - | Max(expr) - | AnyValue(expr, _) - | List(expr) - | Concat(expr) => expr.name(), - MapGroups { func: _, inputs } => inputs.first().unwrap().name(), + Self::Count(expr, ..) + | Self::Sum(expr) + | Self::ApproxPercentile(ApproxPercentileParams { child: expr, .. }) + | Self::ApproxCountDistinct(expr) + | Self::ApproxSketch(expr, _) + | Self::MergeSketch(expr, _) + | Self::Mean(expr) + | Self::Min(expr) + | Self::Max(expr) + | Self::AnyValue(expr, _) + | Self::List(expr) + | Self::Concat(expr) => expr.name(), + Self::MapGroups { func: _, inputs } => inputs.first().unwrap().name(), } } pub fn semantic_id(&self, schema: &Schema) -> FieldID { - use AggExpr::*; match self { - Count(expr, mode) => { + Self::Count(expr, mode) => { let child_id = expr.semantic_id(schema); FieldID::new(format!("{child_id}.local_count({mode})")) } - Sum(expr) => { + Self::Sum(expr) => { let child_id = expr.semantic_id(schema); FieldID::new(format!("{child_id}.local_sum()")) } - ApproxPercentile(ApproxPercentileParams { + Self::ApproxPercentile(ApproxPercentileParams { child: expr, percentiles, force_list_output, @@ -199,122 +200,119 @@ impl AggExpr { percentiles, )) } - ApproxCountDistinct(expr) => { + Self::ApproxCountDistinct(expr) => { let child_id = expr.semantic_id(schema); FieldID::new(format!("{child_id}.local_approx_count_distinct()")) } - ApproxSketch(expr, sketch_type) => { + Self::ApproxSketch(expr, sketch_type) => { let child_id = expr.semantic_id(schema); FieldID::new(format!( "{child_id}.local_approx_sketch(sketch_type={sketch_type:?})" )) } - MergeSketch(expr, sketch_type) => { + Self::MergeSketch(expr, sketch_type) => { let child_id = expr.semantic_id(schema); FieldID::new(format!( "{child_id}.local_merge_sketch(sketch_type={sketch_type:?})" )) } - Mean(expr) => { + Self::Mean(expr) => { let child_id = expr.semantic_id(schema); FieldID::new(format!("{child_id}.local_mean()")) } - Min(expr) => { + Self::Min(expr) => { let child_id = expr.semantic_id(schema); FieldID::new(format!("{child_id}.local_min()")) } - Max(expr) => { + Self::Max(expr) => { let child_id = expr.semantic_id(schema); FieldID::new(format!("{child_id}.local_max()")) } - AnyValue(expr, ignore_nulls) => { + Self::AnyValue(expr, ignore_nulls) => { let child_id = expr.semantic_id(schema); FieldID::new(format!( "{child_id}.local_any_value(ignore_nulls={ignore_nulls})" )) } - List(expr) => { + Self::List(expr) => { let child_id = expr.semantic_id(schema); FieldID::new(format!("{child_id}.local_list()")) } - Concat(expr) => { + Self::Concat(expr) => { let child_id = expr.semantic_id(schema); FieldID::new(format!("{child_id}.local_concat()")) } - MapGroups { func, inputs } => function_semantic_id(func, inputs, schema), + Self::MapGroups { func, inputs } => function_semantic_id(func, inputs, schema), } } pub fn children(&self) -> Vec { - use AggExpr::*; match self { - Count(expr, ..) - | Sum(expr) - | ApproxPercentile(ApproxPercentileParams { child: expr, .. }) - | ApproxCountDistinct(expr) - | ApproxSketch(expr, _) - | MergeSketch(expr, _) - | Mean(expr) - | Min(expr) - | Max(expr) - | AnyValue(expr, _) - | List(expr) - | Concat(expr) => vec![expr.clone()], - MapGroups { func: _, inputs } => inputs.clone(), + Self::Count(expr, ..) + | Self::Sum(expr) + | Self::ApproxPercentile(ApproxPercentileParams { child: expr, .. }) + | Self::ApproxCountDistinct(expr) + | Self::ApproxSketch(expr, _) + | Self::MergeSketch(expr, _) + | Self::Mean(expr) + | Self::Min(expr) + | Self::Max(expr) + | Self::AnyValue(expr, _) + | Self::List(expr) + | Self::Concat(expr) => vec![expr.clone()], + Self::MapGroups { func: _, inputs } => inputs.clone(), } } - pub fn with_new_children(&self, children: Vec) -> Self { - use AggExpr::*; - - if let MapGroups { func: _, inputs } = &self { + pub fn with_new_children(&self, mut children: Vec) -> Self { + if let Self::MapGroups { func: _, inputs } = &self { assert_eq!(children.len(), inputs.len()); } else { assert_eq!(children.len(), 1); } + let mut first_child = || children.pop().unwrap(); match self { - Count(_, count_mode) => Count(children[0].clone(), *count_mode), - Sum(_) => Sum(children[0].clone()), - Mean(_) => Mean(children[0].clone()), - Min(_) => Min(children[0].clone()), - Max(_) => Max(children[0].clone()), - AnyValue(_, ignore_nulls) => AnyValue(children[0].clone(), *ignore_nulls), - List(_) => List(children[0].clone()), - Concat(_) => Concat(children[0].clone()), - MapGroups { func, inputs: _ } => MapGroups { + Self::Count(_, count_mode) => Self::Count(first_child(), *count_mode), + Self::Sum(_) => Self::Sum(first_child()), + Self::Mean(_) => Self::Mean(first_child()), + Self::Min(_) => Self::Min(first_child()), + Self::Max(_) => Self::Max(first_child()), + Self::AnyValue(_, ignore_nulls) => Self::AnyValue(first_child(), *ignore_nulls), + Self::List(_) => Self::List(first_child()), + Self::Concat(_) => Self::Concat(first_child()), + Self::MapGroups { func, inputs: _ } => Self::MapGroups { func: func.clone(), inputs: children, }, - ApproxPercentile(ApproxPercentileParams { + Self::ApproxPercentile(ApproxPercentileParams { percentiles, force_list_output, .. - }) => ApproxPercentile(ApproxPercentileParams { - child: children[0].clone(), + }) => Self::ApproxPercentile(ApproxPercentileParams { + child: first_child(), percentiles: percentiles.clone(), force_list_output: *force_list_output, }), - ApproxCountDistinct(_) => ApproxCountDistinct(children[0].clone()), - &ApproxSketch(_, sketch_type) => ApproxSketch(children[0].clone(), sketch_type), - &MergeSketch(_, sketch_type) => MergeSketch(children[0].clone(), sketch_type), + Self::ApproxCountDistinct(_) => Self::ApproxCountDistinct(first_child()), + &Self::ApproxSketch(_, sketch_type) => Self::ApproxSketch(first_child(), sketch_type), + &Self::MergeSketch(_, sketch_type) => Self::MergeSketch(first_child(), sketch_type), } } pub fn to_field(&self, schema: &Schema) -> DaftResult { - use AggExpr::*; match self { - Count(expr, ..) => { + Self::Count(expr, ..) => { let field = expr.to_field(schema)?; Ok(Field::new(field.name.as_str(), DataType::UInt64)) } - Sum(expr) => { + Self::Sum(expr) => { let field = expr.to_field(schema)?; Ok(Field::new( field.name.as_str(), try_sum_supertype(&field.dtype)?, )) } - ApproxPercentile(ApproxPercentileParams { + Self::ApproxPercentile(ApproxPercentileParams { child: expr, percentiles, force_list_output, @@ -337,11 +335,11 @@ impl AggExpr { }, )) } - ApproxCountDistinct(expr) => { + Self::ApproxCountDistinct(expr) => { let field = expr.to_field(schema)?; Ok(Field::new(field.name.as_str(), DataType::UInt64)) } - ApproxSketch(expr, sketch_type) => { + Self::ApproxSketch(expr, sketch_type) => { let field = expr.to_field(schema)?; let dtype = match sketch_type { SketchType::DDSketch => { @@ -357,7 +355,7 @@ impl AggExpr { }; Ok(Field::new(field.name, dtype)) } - MergeSketch(expr, sketch_type) => { + Self::MergeSketch(expr, sketch_type) => { let field = expr.to_field(schema)?; let dtype = match sketch_type { SketchType::DDSketch => { @@ -374,19 +372,19 @@ impl AggExpr { }; Ok(Field::new(field.name, dtype)) } - Mean(expr) => { + Self::Mean(expr) => { let field = expr.to_field(schema)?; Ok(Field::new( field.name.as_str(), try_mean_supertype(&field.dtype)?, )) } - Min(expr) | Max(expr) | AnyValue(expr, _) => { + Self::Min(expr) | Self::Max(expr) | Self::AnyValue(expr, _) => { let field = expr.to_field(schema)?; Ok(Field::new(field.name.as_str(), field.dtype)) } - List(expr) => expr.to_field(schema)?.to_list_field(), - Concat(expr) => { + Self::List(expr) => expr.to_field(schema)?.to_list_field(), + Self::Concat(expr) => { let field = expr.to_field(schema)?; match field.dtype { DataType::List(..) => Ok(field), @@ -399,19 +397,18 @@ impl AggExpr { ))), } } - MapGroups { func, inputs } => func.to_field(inputs.as_slice(), schema, func), + Self::MapGroups { func, inputs } => func.to_field(inputs.as_slice(), schema, func), } } pub fn from_name_and_child_expr(name: &str, child: ExprRef) -> DaftResult { - use AggExpr::*; match name { - "count" => Ok(Count(child, CountMode::Valid)), - "sum" => Ok(Sum(child)), - "mean" => Ok(Mean(child)), - "min" => Ok(Min(child)), - "max" => Ok(Max(child)), - "list" => Ok(List(child)), + "count" => Ok(Self::Count(child, CountMode::Valid)), + "sum" => Ok(Self::Sum(child)), + "mean" => Ok(Self::Mean(child)), + "min" => Ok(Self::Min(child)), + "max" => Ok(Self::Max(child)), + "list" => Ok(Self::List(child)), _ => Err(DaftError::ValueError(format!( "{} not a valid aggregation name", name @@ -576,57 +573,55 @@ impl Expr { } pub fn semantic_id(&self, schema: &Schema) -> FieldID { - use Expr::*; match self { // Base case - anonymous column reference. // Look up the column name in the provided schema and get its field ID. - Column(name) => FieldID::new(&**name), + Self::Column(name) => FieldID::new(&**name), // Base case - literal. - Literal(value) => FieldID::new(format!("Literal({value:?})")), + Self::Literal(value) => FieldID::new(format!("Literal({value:?})")), // Recursive cases. - Cast(expr, dtype) => { + Self::Cast(expr, dtype) => { let child_id = expr.semantic_id(schema); FieldID::new(format!("{child_id}.cast({dtype})")) } - Not(expr) => { + Self::Not(expr) => { let child_id = expr.semantic_id(schema); FieldID::new(format!("{child_id}.not()")) } - IsNull(expr) => { + Self::IsNull(expr) => { let child_id = expr.semantic_id(schema); FieldID::new(format!("{child_id}.is_null()")) } - NotNull(expr) => { + Self::NotNull(expr) => { let child_id = expr.semantic_id(schema); FieldID::new(format!("{child_id}.not_null()")) } - FillNull(expr, fill_value) => { + Self::FillNull(expr, fill_value) => { let child_id = expr.semantic_id(schema); let fill_value_id = fill_value.semantic_id(schema); FieldID::new(format!("{child_id}.fill_null({fill_value_id})")) } - IsIn(expr, items) => { + Self::IsIn(expr, items) => { let child_id = expr.semantic_id(schema); let items_id = items.semantic_id(schema); FieldID::new(format!("{child_id}.is_in({items_id})")) } - Between(expr, lower, upper) => { + Self::Between(expr, lower, upper) => { let child_id = expr.semantic_id(schema); let lower_id = lower.semantic_id(schema); let upper_id = upper.semantic_id(schema); FieldID::new(format!("{child_id}.between({lower_id},{upper_id})")) } - Function { func, inputs } => function_semantic_id(func, inputs, schema), - BinaryOp { op, left, right } => { + Self::Function { func, inputs } => function_semantic_id(func, inputs, schema), + Self::BinaryOp { op, left, right } => { let left_id = left.semantic_id(schema); let right_id = right.semantic_id(schema); // TODO: check for symmetry here. FieldID::new(format!("({left_id} {op} {right_id})")) } - - IfElse { + Self::IfElse { if_true, if_false, predicate, @@ -636,96 +631,100 @@ impl Expr { let predicate = predicate.semantic_id(schema); FieldID::new(format!("({if_true} if {predicate} else {if_false})")) } - // Alias: ID does not change. - Alias(expr, ..) => expr.semantic_id(schema), - + Self::Alias(expr, ..) => expr.semantic_id(schema), // Agg: Separate path. - Agg(agg_expr) => agg_expr.semantic_id(schema), - ScalarFunction(sf) => scalar_function_semantic_id(sf, schema), + Self::Agg(agg_expr) => agg_expr.semantic_id(schema), + Self::ScalarFunction(sf) => scalar_function_semantic_id(sf, schema), } } pub fn children(&self) -> Vec { - use Expr::*; match self { // No children. - Column(..) => vec![], - Literal(..) => vec![], + Self::Column(..) => vec![], + Self::Literal(..) => vec![], // One child. - Not(expr) | IsNull(expr) | NotNull(expr) | Cast(expr, ..) | Alias(expr, ..) => { + Self::Not(expr) + | Self::IsNull(expr) + | Self::NotNull(expr) + | Self::Cast(expr, ..) + | Self::Alias(expr, ..) => { vec![expr.clone()] } - Agg(agg_expr) => agg_expr.children(), + Self::Agg(agg_expr) => agg_expr.children(), // Multiple children. - Function { inputs, .. } => inputs.clone(), - BinaryOp { left, right, .. } => { + Self::Function { inputs, .. } => inputs.clone(), + Self::BinaryOp { left, right, .. } => { vec![left.clone(), right.clone()] } - IsIn(expr, items) => vec![expr.clone(), items.clone()], - Between(expr, lower, upper) => vec![expr.clone(), lower.clone(), upper.clone()], - IfElse { + Self::IsIn(expr, items) => vec![expr.clone(), items.clone()], + Self::Between(expr, lower, upper) => vec![expr.clone(), lower.clone(), upper.clone()], + Self::IfElse { if_true, if_false, predicate, } => { vec![if_true.clone(), if_false.clone(), predicate.clone()] } - FillNull(expr, fill_value) => vec![expr.clone(), fill_value.clone()], - ScalarFunction(sf) => sf.inputs.clone(), + Self::FillNull(expr, fill_value) => vec![expr.clone(), fill_value.clone()], + Self::ScalarFunction(sf) => sf.inputs.clone(), } } pub fn with_new_children(&self, children: Vec) -> Self { - use Expr::*; match self { // no children - Column(..) | Literal(..) => { + Self::Column(..) | Self::Literal(..) => { assert!(children.is_empty(), "Should have no children"); self.clone() } // 1 child - Not(..) => Not(children.first().expect("Should have 1 child").clone()), - Alias(.., name) => Alias( + Self::Not(..) => Self::Not(children.first().expect("Should have 1 child").clone()), + Self::Alias(.., name) => Self::Alias( children.first().expect("Should have 1 child").clone(), name.clone(), ), - IsNull(..) => IsNull(children.first().expect("Should have 1 child").clone()), - NotNull(..) => NotNull(children.first().expect("Should have 1 child").clone()), - Cast(.., dtype) => Cast( + Self::IsNull(..) => { + Self::IsNull(children.first().expect("Should have 1 child").clone()) + } + Self::NotNull(..) => { + Self::NotNull(children.first().expect("Should have 1 child").clone()) + } + Self::Cast(.., dtype) => Self::Cast( children.first().expect("Should have 1 child").clone(), dtype.clone(), ), // 2 children - BinaryOp { op, .. } => BinaryOp { + Self::BinaryOp { op, .. } => Self::BinaryOp { op: *op, left: children.first().expect("Should have 1 child").clone(), right: children.get(1).expect("Should have 2 child").clone(), }, - IsIn(..) => IsIn( + Self::IsIn(..) => Self::IsIn( children.first().expect("Should have 1 child").clone(), children.get(1).expect("Should have 2 child").clone(), ), - Between(..) => Between( + Self::Between(..) => Self::Between( children.first().expect("Should have 1 child").clone(), children.get(1).expect("Should have 2 child").clone(), children.get(2).expect("Should have 3 child").clone(), ), - FillNull(..) => FillNull( + Self::FillNull(..) => Self::FillNull( children.first().expect("Should have 1 child").clone(), children.get(1).expect("Should have 2 child").clone(), ), // ternary - IfElse { .. } => IfElse { + Self::IfElse { .. } => Self::IfElse { if_true: children.first().expect("Should have 1 child").clone(), if_false: children.get(1).expect("Should have 2 child").clone(), predicate: children.get(2).expect("Should have 3 child").clone(), }, // N-ary - Agg(agg_expr) => Agg(agg_expr.with_new_children(children)), - Function { + Self::Agg(agg_expr) => Self::Agg(agg_expr.with_new_children(children)), + Self::Function { func, inputs: old_children, } => { @@ -733,18 +732,18 @@ impl Expr { children.len() == old_children.len(), "Should have same number of children" ); - Function { + Self::Function { func: func.clone(), inputs: children, } } - ScalarFunction(sf) => { + Self::ScalarFunction(sf) => { assert!( children.len() == sf.inputs.len(), "Should have same number of children" ); - ScalarFunction(crate::functions::ScalarFunction { + Self::ScalarFunction(crate::functions::ScalarFunction { udf: sf.udf.clone(), inputs: children, }) @@ -753,13 +752,12 @@ impl Expr { } pub fn to_field(&self, schema: &Schema) -> DaftResult { - use Expr::*; match self { - Alias(expr, name) => Ok(Field::new(name.as_ref(), expr.get_type(schema)?)), - Agg(agg_expr) => agg_expr.to_field(schema), - Cast(expr, dtype) => Ok(Field::new(expr.name(), dtype.clone())), - Column(name) => Ok(schema.get_field(name).cloned()?), - Not(expr) => { + Self::Alias(expr, name) => Ok(Field::new(name.as_ref(), expr.get_type(schema)?)), + Self::Agg(agg_expr) => agg_expr.to_field(schema), + Self::Cast(expr, dtype) => Ok(Field::new(expr.name(), dtype.clone())), + Self::Column(name) => Ok(schema.get_field(name).cloned()?), + Self::Not(expr) => { let child_field = expr.to_field(schema)?; match child_field.dtype { DataType::Boolean => Ok(Field::new(expr.name(), DataType::Boolean)), @@ -768,9 +766,9 @@ impl Expr { ))), } } - IsNull(expr) => Ok(Field::new(expr.name(), DataType::Boolean)), - NotNull(expr) => Ok(Field::new(expr.name(), DataType::Boolean)), - FillNull(expr, fill_value) => { + Self::IsNull(expr) => Ok(Field::new(expr.name(), DataType::Boolean)), + Self::NotNull(expr) => Ok(Field::new(expr.name(), DataType::Boolean)), + Self::FillNull(expr, fill_value) => { let expr_field = expr.to_field(schema)?; let fill_value_field = fill_value.to_field(schema)?; match try_get_supertype(&expr_field.dtype, &fill_value_field.dtype) { @@ -780,7 +778,7 @@ impl Expr { ))) } } - IsIn(left, right) => { + Self::IsIn(left, right) => { let left_field = left.to_field(schema)?; let right_field = right.to_field(schema)?; let (result_type, _intermediate, _comp_type) = @@ -788,7 +786,7 @@ impl Expr { .membership_op(&InferDataType::from(&right_field.dtype))?; Ok(Field::new(left_field.name.as_str(), result_type)) } - Between(value, lower, upper) => { + Self::Between(value, lower, upper) => { let value_field = value.to_field(schema)?; let lower_field = lower.to_field(schema)?; let upper_field = upper.to_field(schema)?; @@ -803,11 +801,10 @@ impl Expr { .membership_op(&InferDataType::from(&upper_result_type))?; Ok(Field::new(value_field.name.as_str(), result_type)) } - Literal(value) => Ok(Field::new("literal", value.get_type())), - Function { func, inputs } => func.to_field(inputs.as_slice(), schema, func), - ScalarFunction(sf) => sf.to_field(schema), - - BinaryOp { op, left, right } => { + Self::Literal(value) => Ok(Field::new("literal", value.get_type())), + Self::Function { func, inputs } => func.to_field(inputs.as_slice(), schema, func), + Self::ScalarFunction(sf) => sf.to_field(schema), + Self::BinaryOp { op, left, right } => { let left_field = left.to_field(schema)?; let right_field = right.to_field(schema)?; @@ -873,7 +870,7 @@ impl Expr { } } } - IfElse { + Self::IfElse { if_true, if_false, predicate, @@ -903,33 +900,32 @@ impl Expr { } pub fn name(&self) -> &str { - use Expr::*; match self { - Alias(.., name) => name.as_ref(), - Agg(agg_expr) => agg_expr.name(), - Cast(expr, ..) => expr.name(), - Column(name) => name.as_ref(), - Not(expr) => expr.name(), - IsNull(expr) => expr.name(), - NotNull(expr) => expr.name(), - FillNull(expr, ..) => expr.name(), - IsIn(expr, ..) => expr.name(), - Between(expr, ..) => expr.name(), - Literal(..) => "literal", - Function { func, inputs } => match func { + Self::Alias(.., name) => name.as_ref(), + Self::Agg(agg_expr) => agg_expr.name(), + Self::Cast(expr, ..) => expr.name(), + Self::Column(name) => name.as_ref(), + Self::Not(expr) => expr.name(), + Self::IsNull(expr) => expr.name(), + Self::NotNull(expr) => expr.name(), + Self::FillNull(expr, ..) => expr.name(), + Self::IsIn(expr, ..) => expr.name(), + Self::Between(expr, ..) => expr.name(), + Self::Literal(..) => "literal", + Self::Function { func, inputs } => match func { FunctionExpr::Struct(StructExpr::Get(name)) => name, _ => inputs.first().unwrap().name(), }, - ScalarFunction(func) => match func.name() { + Self::ScalarFunction(func) => match func.name() { "to_struct" => "struct", // FIXME: make .name() use output name from schema _ => func.inputs.first().unwrap().name(), }, - BinaryOp { + Self::BinaryOp { op: _, left, right: _, } => left.name(), - IfElse { if_true, .. } => if_true.name(), + Self::IfElse { if_true, .. } => if_true.name(), } } @@ -1119,90 +1115,3 @@ pub fn has_stateful_udf(expr: &ExprRef) -> bool { ) }) } - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn check_comparison_type() -> DaftResult<()> { - let x = lit(10.); - let y = lit(12); - let schema = Schema::empty(); - - let z = Expr::BinaryOp { - left: x, - right: y, - op: Operator::Lt, - }; - assert_eq!(z.get_type(&schema)?, DataType::Boolean); - Ok(()) - } - - #[test] - fn check_alias_type() -> DaftResult<()> { - let a = col("a"); - let b = a.alias("b"); - match b.as_ref() { - Expr::Alias(..) => Ok(()), - other => Err(common_error::DaftError::ValueError(format!( - "expected expression to be a alias, got {other:?}" - ))), - } - } - - #[test] - fn check_arithmetic_type() -> DaftResult<()> { - let x = lit(10.); - let y = lit(12); - let schema = Schema::empty(); - - let z = Expr::BinaryOp { - left: x, - right: y, - op: Operator::Plus, - }; - assert_eq!(z.get_type(&schema)?, DataType::Float64); - - let x = lit(10.); - let y = lit(12); - - let z = Expr::BinaryOp { - left: y, - right: x, - op: Operator::Plus, - }; - assert_eq!(z.get_type(&schema)?, DataType::Float64); - - Ok(()) - } - - #[test] - fn check_arithmetic_type_with_columns() -> DaftResult<()> { - let x = col("x"); - let y = col("y"); - let schema = Schema::new(vec![ - Field::new("x", DataType::Float64), - Field::new("y", DataType::Int64), - ])?; - - let z = Expr::BinaryOp { - left: x, - right: y, - op: Operator::Plus, - }; - assert_eq!(z.get_type(&schema)?, DataType::Float64); - - let x = col("x"); - let y = col("y"); - - let z = Expr::BinaryOp { - left: y, - right: x, - op: Operator::Plus, - }; - assert_eq!(z.get_type(&schema)?, DataType::Float64); - - Ok(()) - } -} diff --git a/src/daft-dsl/src/expr/tests.rs b/src/daft-dsl/src/expr/tests.rs new file mode 100644 index 0000000000..aff680c5d3 --- /dev/null +++ b/src/daft-dsl/src/expr/tests.rs @@ -0,0 +1,83 @@ +use super::*; + +#[test] +fn check_comparison_type() -> DaftResult<()> { + let x = lit(10.); + let y = lit(12); + let schema = Schema::empty(); + + let z = Expr::BinaryOp { + left: x, + right: y, + op: Operator::Lt, + }; + assert_eq!(z.get_type(&schema)?, DataType::Boolean); + Ok(()) +} + +#[test] +fn check_alias_type() -> DaftResult<()> { + let a = col("a"); + let b = a.alias("b"); + match b.as_ref() { + Expr::Alias(..) => Ok(()), + other => Err(common_error::DaftError::ValueError(format!( + "expected expression to be a alias, got {other:?}" + ))), + } +} + +#[test] +fn check_arithmetic_type() -> DaftResult<()> { + let x = lit(10.); + let y = lit(12); + let schema = Schema::empty(); + + let z = Expr::BinaryOp { + left: x, + right: y, + op: Operator::Plus, + }; + assert_eq!(z.get_type(&schema)?, DataType::Float64); + + let x = lit(10.); + let y = lit(12); + + let z = Expr::BinaryOp { + left: y, + right: x, + op: Operator::Plus, + }; + assert_eq!(z.get_type(&schema)?, DataType::Float64); + + Ok(()) +} + +#[test] +fn check_arithmetic_type_with_columns() -> DaftResult<()> { + let x = col("x"); + let y = col("y"); + let schema = Schema::new(vec![ + Field::new("x", DataType::Float64), + Field::new("y", DataType::Int64), + ])?; + + let z = Expr::BinaryOp { + left: x, + right: y, + op: Operator::Plus, + }; + assert_eq!(z.get_type(&schema)?, DataType::Float64); + + let x = col("x"); + let y = col("y"); + + let z = Expr::BinaryOp { + left: y, + right: x, + op: Operator::Plus, + }; + assert_eq!(z.get_type(&schema)?, DataType::Float64); + + Ok(()) +} diff --git a/src/daft-dsl/src/functions/map/mod.rs b/src/daft-dsl/src/functions/map/mod.rs index 979a6ccd1e..083e99e7db 100644 --- a/src/daft-dsl/src/functions/map/mod.rs +++ b/src/daft-dsl/src/functions/map/mod.rs @@ -14,9 +14,8 @@ pub enum MapExpr { impl MapExpr { #[inline] pub fn get_evaluator(&self) -> &dyn FunctionEvaluator { - use MapExpr::*; match self { - Get => &GetEvaluator {}, + Self::Get => &GetEvaluator {}, } } } diff --git a/src/daft-dsl/src/functions/mod.rs b/src/daft-dsl/src/functions/mod.rs index 0386d7c54c..6f0b162422 100644 --- a/src/daft-dsl/src/functions/mod.rs +++ b/src/daft-dsl/src/functions/mod.rs @@ -1,5 +1,6 @@ pub mod map; pub mod partitioning; +pub mod python; pub mod scalar; pub mod sketch; pub mod struct_; @@ -12,6 +13,7 @@ use std::{ use common_error::DaftResult; use daft_core::prelude::*; +use python::PythonUDF; pub use scalar::*; use serde::{Deserialize, Serialize}; @@ -21,9 +23,6 @@ use self::{ }; use crate::{Expr, ExprRef, Operator}; -pub mod python; -use python::PythonUDF; - #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] pub enum FunctionExpr { Utf8(Utf8Expr), @@ -48,14 +47,13 @@ pub trait FunctionEvaluator { impl FunctionExpr { #[inline] fn get_evaluator(&self) -> &dyn FunctionEvaluator { - use FunctionExpr::*; match self { - Utf8(expr) => expr.get_evaluator(), - Map(expr) => expr.get_evaluator(), - Sketch(expr) => expr.get_evaluator(), - Struct(expr) => expr.get_evaluator(), - Python(expr) => expr.get_evaluator(), - Partitioning(expr) => expr.get_evaluator(), + Self::Utf8(expr) => expr.get_evaluator(), + Self::Map(expr) => expr.get_evaluator(), + Self::Sketch(expr) => expr.get_evaluator(), + Self::Struct(expr) => expr.get_evaluator(), + Self::Python(expr) => expr.get_evaluator(), + Self::Partitioning(expr) => expr.get_evaluator(), } } } diff --git a/src/daft-dsl/src/functions/partitioning/mod.rs b/src/daft-dsl/src/functions/partitioning/mod.rs index 9f37414e18..ead6ed91f8 100644 --- a/src/daft-dsl/src/functions/partitioning/mod.rs +++ b/src/daft-dsl/src/functions/partitioning/mod.rs @@ -24,14 +24,13 @@ pub enum PartitioningExpr { impl PartitioningExpr { #[inline] pub fn get_evaluator(&self) -> &dyn FunctionEvaluator { - use PartitioningExpr::*; match self { - Years => &YearsEvaluator {}, - Months => &MonthsEvaluator {}, - Days => &DaysEvaluator {}, - Hours => &HoursEvaluator {}, - IcebergBucket(..) => &IcebergBucketEvaluator {}, - IcebergTruncate(..) => &IcebergTruncateEvaluator {}, + Self::Years => &YearsEvaluator {}, + Self::Months => &MonthsEvaluator {}, + Self::Days => &DaysEvaluator {}, + Self::Hours => &HoursEvaluator {}, + Self::IcebergBucket(..) => &IcebergBucketEvaluator {}, + Self::IcebergTruncate(..) => &IcebergTruncateEvaluator {}, } } } diff --git a/src/daft-dsl/src/functions/python/mod.rs b/src/daft-dsl/src/functions/python/mod.rs index adbb2830e7..c09adb0439 100644 --- a/src/daft-dsl/src/functions/python/mod.rs +++ b/src/daft-dsl/src/functions/python/mod.rs @@ -2,9 +2,13 @@ mod runtime_py_object; mod udf; mod udf_runtime_binding; -use std::{collections::HashMap, sync::Arc}; +#[cfg(feature = "python")] +use std::collections::HashMap; +use std::sync::Arc; -use common_error::{DaftError, DaftResult}; +#[cfg(feature = "python")] +use common_error::DaftError; +use common_error::DaftResult; use common_resource_request::ResourceRequest; use common_treenode::{TreeNode, TreeNodeRecursion}; use daft_core::datatypes::DataType; diff --git a/src/daft-dsl/src/functions/python/udf.rs b/src/daft-dsl/src/functions/python/udf.rs index 7b9da47bf7..100bd06566 100644 --- a/src/daft-dsl/src/functions/python/udf.rs +++ b/src/daft-dsl/src/functions/python/udf.rs @@ -1,5 +1,5 @@ use common_error::{DaftError, DaftResult}; -use daft_core::{datatypes::DataType, prelude::*}; +use daft_core::prelude::*; #[cfg(feature = "python")] use pyo3::{ types::{PyAnyMethods, PyModule}, diff --git a/src/daft-dsl/src/functions/sketch/mod.rs b/src/daft-dsl/src/functions/sketch/mod.rs index 87c5df6f6d..4e2fdc566b 100644 --- a/src/daft-dsl/src/functions/sketch/mod.rs +++ b/src/daft-dsl/src/functions/sketch/mod.rs @@ -30,9 +30,8 @@ pub enum SketchExpr { impl SketchExpr { #[inline] pub fn get_evaluator(&self) -> &dyn FunctionEvaluator { - use SketchExpr::*; match self { - Percentile { .. } => &PercentileEvaluator {}, + Self::Percentile { .. } => &PercentileEvaluator {}, } } } diff --git a/src/daft-dsl/src/functions/struct_/mod.rs b/src/daft-dsl/src/functions/struct_/mod.rs index c842c45c64..7d8d192d25 100644 --- a/src/daft-dsl/src/functions/struct_/mod.rs +++ b/src/daft-dsl/src/functions/struct_/mod.rs @@ -14,9 +14,8 @@ pub enum StructExpr { impl StructExpr { #[inline] pub fn get_evaluator(&self) -> &dyn FunctionEvaluator { - use StructExpr::*; match self { - Get(_) => &GetEvaluator {}, + Self::Get(_) => &GetEvaluator {}, } } } diff --git a/src/daft-dsl/src/functions/utf8/mod.rs b/src/daft-dsl/src/functions/utf8/mod.rs index cb3a07aca1..7a795250ff 100644 --- a/src/daft-dsl/src/functions/utf8/mod.rs +++ b/src/daft-dsl/src/functions/utf8/mod.rs @@ -95,36 +95,35 @@ pub enum Utf8Expr { impl Utf8Expr { #[inline] pub fn get_evaluator(&self) -> &dyn FunctionEvaluator { - use Utf8Expr::*; match self { - EndsWith => &EndswithEvaluator {}, - StartsWith => &StartswithEvaluator {}, - Contains => &ContainsEvaluator {}, - Split(_) => &SplitEvaluator {}, - Match => &MatchEvaluator {}, - Extract(_) => &ExtractEvaluator {}, - ExtractAll(_) => &ExtractAllEvaluator {}, - Replace(_) => &ReplaceEvaluator {}, - Length => &LengthEvaluator {}, - LengthBytes => &LengthBytesEvaluator {}, - Lower => &LowerEvaluator {}, - Upper => &UpperEvaluator {}, - Lstrip => &LstripEvaluator {}, - Rstrip => &RstripEvaluator {}, - Reverse => &ReverseEvaluator {}, - Capitalize => &CapitalizeEvaluator {}, - Left => &LeftEvaluator {}, - Right => &RightEvaluator {}, - Find => &FindEvaluator {}, - Rpad => &RpadEvaluator {}, - Lpad => &LpadEvaluator {}, - Repeat => &RepeatEvaluator {}, - Like => &LikeEvaluator {}, - Ilike => &IlikeEvaluator {}, - Substr => &SubstrEvaluator {}, - ToDate(_) => &ToDateEvaluator {}, - ToDatetime(_, _) => &ToDatetimeEvaluator {}, - Normalize(_) => &NormalizeEvaluator {}, + Self::EndsWith => &EndswithEvaluator {}, + Self::StartsWith => &StartswithEvaluator {}, + Self::Contains => &ContainsEvaluator {}, + Self::Split(_) => &SplitEvaluator {}, + Self::Match => &MatchEvaluator {}, + Self::Extract(_) => &ExtractEvaluator {}, + Self::ExtractAll(_) => &ExtractAllEvaluator {}, + Self::Replace(_) => &ReplaceEvaluator {}, + Self::Length => &LengthEvaluator {}, + Self::LengthBytes => &LengthBytesEvaluator {}, + Self::Lower => &LowerEvaluator {}, + Self::Upper => &UpperEvaluator {}, + Self::Lstrip => &LstripEvaluator {}, + Self::Rstrip => &RstripEvaluator {}, + Self::Reverse => &ReverseEvaluator {}, + Self::Capitalize => &CapitalizeEvaluator {}, + Self::Left => &LeftEvaluator {}, + Self::Right => &RightEvaluator {}, + Self::Find => &FindEvaluator {}, + Self::Rpad => &RpadEvaluator {}, + Self::Lpad => &LpadEvaluator {}, + Self::Repeat => &RepeatEvaluator {}, + Self::Like => &LikeEvaluator {}, + Self::Ilike => &IlikeEvaluator {}, + Self::Substr => &SubstrEvaluator {}, + Self::ToDate(_) => &ToDateEvaluator {}, + Self::ToDatetime(_, _) => &ToDatetimeEvaluator {}, + Self::Normalize(_) => &NormalizeEvaluator {}, } } } diff --git a/src/daft-dsl/src/join.rs b/src/daft-dsl/src/join/mod.rs similarity index 79% rename from src/daft-dsl/src/join.rs rename to src/daft-dsl/src/join/mod.rs index 2f1cf96cb2..1de29b995e 100644 --- a/src/daft-dsl/src/join.rs +++ b/src/daft-dsl/src/join/mod.rs @@ -1,3 +1,6 @@ +#[cfg(test)] +mod tests; + use std::sync::Arc; use common_error::{DaftError, DaftResult}; @@ -79,34 +82,3 @@ pub fn infer_join_schema( Ok(Schema::new(fields)?.into()) } } - -#[cfg(test)] -mod tests { - use super::*; - use crate::col; - - #[test] - fn test_get_common_join_keys() { - let left_on: &[ExprRef] = &[ - col("a"), - col("b_left"), - col("c").alias("c_new"), - col("d").alias("d_new"), - col("e").add(col("f")), - ]; - - let right_on: &[ExprRef] = &[ - col("a"), - col("b_right"), - col("c"), - col("d").alias("d_new"), - col("e"), - ]; - - let common_join_keys = get_common_join_keys(left_on, right_on) - .map(|k| k.to_string()) - .collect::>(); - - assert_eq!(common_join_keys, vec!["a"]); - } -} diff --git a/src/daft-dsl/src/join/tests.rs b/src/daft-dsl/src/join/tests.rs new file mode 100644 index 0000000000..52d58a76c0 --- /dev/null +++ b/src/daft-dsl/src/join/tests.rs @@ -0,0 +1,27 @@ +use super::*; +use crate::col; + +#[test] +fn test_get_common_join_keys() { + let left_on: &[ExprRef] = &[ + col("a"), + col("b_left"), + col("c").alias("c_new"), + col("d").alias("d_new"), + col("e").add(col("f")), + ]; + + let right_on: &[ExprRef] = &[ + col("a"), + col("b_right"), + col("c"), + col("d").alias("d_new"), + col("e"), + ]; + + let common_join_keys = get_common_join_keys(left_on, right_on) + .map(|k| k.to_string()) + .collect::>(); + + assert_eq!(common_join_keys, vec!["a"]); +} diff --git a/src/daft-dsl/src/lit.rs b/src/daft-dsl/src/lit.rs index 55888d73f8..ac257532e1 100644 --- a/src/daft-dsl/src/lit.rs +++ b/src/daft-dsl/src/lit.rs @@ -77,36 +77,34 @@ impl Eq for LiteralValue {} impl Hash for LiteralValue { fn hash(&self, state: &mut H) { - use LiteralValue::*; - match self { // Stable hash for Null variant. - Null => 1.hash(state), - Boolean(bool) => bool.hash(state), - Utf8(s) => s.hash(state), - Binary(arr) => arr.hash(state), - Int32(n) => n.hash(state), - UInt32(n) => n.hash(state), - Int64(n) => n.hash(state), - UInt64(n) => n.hash(state), - Date(n) => n.hash(state), - Time(n, tu) => { + Self::Null => 1.hash(state), + Self::Boolean(bool) => bool.hash(state), + Self::Utf8(s) => s.hash(state), + Self::Binary(arr) => arr.hash(state), + Self::Int32(n) => n.hash(state), + Self::UInt32(n) => n.hash(state), + Self::Int64(n) => n.hash(state), + Self::UInt64(n) => n.hash(state), + Self::Date(n) => n.hash(state), + Self::Time(n, tu) => { n.hash(state); tu.hash(state); } - Timestamp(n, tu, tz) => { + Self::Timestamp(n, tu, tz) => { n.hash(state); tu.hash(state); tz.hash(state); } // Wrap float64 in hashable newtype. - Float64(n) => FloatWrapper(*n).hash(state), - Decimal(n, precision, scale) => { + Self::Float64(n) => FloatWrapper(*n).hash(state), + Self::Decimal(n, precision, scale) => { n.hash(state); precision.hash(state); scale.hash(state); } - Series(series) => { + Self::Series(series) => { let hash_result = series.hash(None); match hash_result { Ok(hash) => hash.into_iter().for_each(|i| i.hash(state)), @@ -114,8 +112,8 @@ impl Hash for LiteralValue { } } #[cfg(feature = "python")] - Python(py_obj) => py_obj.hash(state), - Struct(entries) => { + Self::Python(py_obj) => py_obj.hash(state), + Self::Struct(entries) => { entries.iter().for_each(|(v, f)| { v.hash(state); f.hash(state); @@ -128,31 +126,32 @@ impl Hash for LiteralValue { impl Display for LiteralValue { // `f` is a buffer, and this method must write the formatted string into it fn fmt(&self, f: &mut Formatter) -> Result { - use LiteralValue::*; match self { - Null => write!(f, "Null"), - Boolean(val) => write!(f, "{val}"), - Utf8(val) => write!(f, "\"{val}\""), - Binary(val) => write!(f, "Binary[{}]", val.len()), - Int32(val) => write!(f, "{val}"), - UInt32(val) => write!(f, "{val}"), - Int64(val) => write!(f, "{val}"), - UInt64(val) => write!(f, "{val}"), - Date(val) => write!(f, "{}", display_date32(*val)), - Time(val, tu) => write!(f, "{}", display_time64(*val, tu)), - Timestamp(val, tu, tz) => write!(f, "{}", display_timestamp(*val, tu, tz)), - Float64(val) => write!(f, "{val:.1}"), - Decimal(val, precision, scale) => { + Self::Null => write!(f, "Null"), + Self::Boolean(val) => write!(f, "{val}"), + Self::Utf8(val) => write!(f, "\"{val}\""), + Self::Binary(val) => write!(f, "Binary[{}]", val.len()), + Self::Int32(val) => write!(f, "{val}"), + Self::UInt32(val) => write!(f, "{val}"), + Self::Int64(val) => write!(f, "{val}"), + Self::UInt64(val) => write!(f, "{val}"), + Self::Date(val) => write!(f, "{}", display_date32(*val)), + Self::Time(val, tu) => write!(f, "{}", display_time64(*val, tu)), + Self::Timestamp(val, tu, tz) => { + write!(f, "{}", display_timestamp(*val, tu, tz)) + } + Self::Float64(val) => write!(f, "{val:.1}"), + Self::Decimal(val, precision, scale) => { write!(f, "{}", display_decimal128(*val, *precision, *scale)) } - Series(series) => write!(f, "{}", display_series_literal(series)), + Self::Series(series) => write!(f, "{}", display_series_literal(series)), #[cfg(feature = "python")] - Python(pyobj) => write!(f, "PyObject({})", { + Self::Python(pyobj) => write!(f, "PyObject({})", { use pyo3::prelude::*; Python::with_gil(|py| pyobj.0.call_method0(py, pyo3::intern!(py, "__str__"))) .unwrap() }), - Struct(entries) => { + Self::Struct(entries) => { write!(f, "Struct(")?; for (i, (field, v)) in entries.iter().enumerate() { if i > 0 { @@ -168,101 +167,101 @@ impl Display for LiteralValue { impl LiteralValue { pub fn get_type(&self) -> DataType { - use LiteralValue::*; match self { - Null => DataType::Null, - Boolean(_) => DataType::Boolean, - Utf8(_) => DataType::Utf8, - Binary(_) => DataType::Binary, - Int32(_) => DataType::Int32, - UInt32(_) => DataType::UInt32, - Int64(_) => DataType::Int64, - UInt64(_) => DataType::UInt64, - Date(_) => DataType::Date, - Time(_, tu) => DataType::Time(*tu), - Timestamp(_, tu, tz) => DataType::Timestamp(*tu, tz.clone()), - Float64(_) => DataType::Float64, - Decimal(_, precision, scale) => { + Self::Null => DataType::Null, + Self::Boolean(_) => DataType::Boolean, + Self::Utf8(_) => DataType::Utf8, + Self::Binary(_) => DataType::Binary, + Self::Int32(_) => DataType::Int32, + Self::UInt32(_) => DataType::UInt32, + Self::Int64(_) => DataType::Int64, + Self::UInt64(_) => DataType::UInt64, + Self::Date(_) => DataType::Date, + Self::Time(_, tu) => DataType::Time(*tu), + Self::Timestamp(_, tu, tz) => DataType::Timestamp(*tu, tz.clone()), + Self::Float64(_) => DataType::Float64, + Self::Decimal(_, precision, scale) => { DataType::Decimal128(*precision as usize, *scale as usize) } - Series(series) => series.data_type().clone(), + Self::Series(series) => series.data_type().clone(), #[cfg(feature = "python")] - Python(_) => DataType::Python, - Struct(entries) => DataType::Struct(entries.keys().cloned().collect()), + Self::Python(_) => DataType::Python, + Self::Struct(entries) => DataType::Struct(entries.keys().cloned().collect()), } } pub fn to_series(&self) -> Series { - use LiteralValue::*; - let result = match self { - Null => NullArray::full_null("literal", &DataType::Null, 1).into_series(), - Boolean(val) => BooleanArray::from(("literal", [*val].as_slice())).into_series(), - Utf8(val) => Utf8Array::from(("literal", [val.as_str()].as_slice())).into_series(), - Binary(val) => BinaryArray::from(("literal", val.as_slice())).into_series(), - Int32(val) => Int32Array::from(("literal", [*val].as_slice())).into_series(), - UInt32(val) => UInt32Array::from(("literal", [*val].as_slice())).into_series(), - Int64(val) => Int64Array::from(("literal", [*val].as_slice())).into_series(), - UInt64(val) => UInt64Array::from(("literal", [*val].as_slice())).into_series(), - Date(val) => { + match self { + Self::Null => NullArray::full_null("literal", &DataType::Null, 1).into_series(), + Self::Boolean(val) => BooleanArray::from(("literal", [*val].as_slice())).into_series(), + Self::Utf8(val) => { + Utf8Array::from(("literal", [val.as_str()].as_slice())).into_series() + } + Self::Binary(val) => BinaryArray::from(("literal", val.as_slice())).into_series(), + Self::Int32(val) => Int32Array::from(("literal", [*val].as_slice())).into_series(), + Self::UInt32(val) => UInt32Array::from(("literal", [*val].as_slice())).into_series(), + Self::Int64(val) => Int64Array::from(("literal", [*val].as_slice())).into_series(), + Self::UInt64(val) => UInt64Array::from(("literal", [*val].as_slice())).into_series(), + Self::Date(val) => { let physical = Int32Array::from(("literal", [*val].as_slice())); DateArray::new(Field::new("literal", self.get_type()), physical).into_series() } - Time(val, ..) => { + Self::Time(val, ..) => { let physical = Int64Array::from(("literal", [*val].as_slice())); TimeArray::new(Field::new("literal", self.get_type()), physical).into_series() } - Timestamp(val, ..) => { + Self::Timestamp(val, ..) => { let physical = Int64Array::from(("literal", [*val].as_slice())); TimestampArray::new(Field::new("literal", self.get_type()), physical).into_series() } - Float64(val) => Float64Array::from(("literal", [*val].as_slice())).into_series(), - Decimal(val, ..) => { + Self::Float64(val) => Float64Array::from(("literal", [*val].as_slice())).into_series(), + Self::Decimal(val, ..) => { let physical = Int128Array::from(("literal", [*val].as_slice())); Decimal128Array::new(Field::new("literal", self.get_type()), physical).into_series() } - Series(series) => series.clone().rename("literal"), + Self::Series(series) => series.clone().rename("literal"), #[cfg(feature = "python")] - Python(val) => PythonArray::from(("literal", vec![val.0.clone()])).into_series(), - Struct(entries) => { + Self::Python(val) => PythonArray::from(("literal", vec![val.0.clone()])).into_series(), + Self::Struct(entries) => { let struct_dtype = DataType::Struct(entries.keys().cloned().collect()); let struct_field = Field::new("literal", struct_dtype); let values = entries.values().map(|v| v.to_series()).collect(); StructArray::new(struct_field, values, None).into_series() } - }; - result + } } pub fn display_sql(&self, buffer: &mut W) -> io::Result<()> { - use LiteralValue::*; let display_sql_err = Err(io::Error::new( io::ErrorKind::Other, "Unsupported literal for SQL translation", )); match self { - Null => write!(buffer, "NULL"), - Boolean(v) => write!(buffer, "{}", v), - Int32(val) => write!(buffer, "{}", val), - UInt32(val) => write!(buffer, "{}", val), - Int64(val) => write!(buffer, "{}", val), - UInt64(val) => write!(buffer, "{}", val), - Float64(val) => write!(buffer, "{}", val), - Utf8(val) => write!(buffer, "'{}'", val), - Date(val) => write!(buffer, "DATE '{}'", display_date32(*val)), + Self::Null => write!(buffer, "NULL"), + Self::Boolean(v) => write!(buffer, "{}", v), + Self::Int32(val) => write!(buffer, "{}", val), + Self::UInt32(val) => write!(buffer, "{}", val), + Self::Int64(val) => write!(buffer, "{}", val), + Self::UInt64(val) => write!(buffer, "{}", val), + Self::Float64(val) => write!(buffer, "{}", val), + Self::Utf8(val) => write!(buffer, "'{}'", val), + Self::Date(val) => write!(buffer, "DATE '{}'", display_date32(*val)), // The `display_timestamp` function formats a timestamp in the ISO 8601 format: "YYYY-MM-DDTHH:MM:SS.fffff". // ANSI SQL standard uses a space instead of 'T'. Some databases do not support 'T', hence it's replaced with a space. // Reference: https://docs.actian.com/ingres/10s/index.html#page/SQLRef/Summary_of_ANSI_Date_2fTime_Data_Types.html - Timestamp(val, tu, tz) => write!( + Self::Timestamp(val, tu, tz) => write!( buffer, "TIMESTAMP '{}'", display_timestamp(*val, tu, tz).replace('T', " ") ), // TODO(Colin): Implement the rest of the types in future work for SQL pushdowns. - Decimal(..) | Series(..) | Time(..) | Binary(..) => display_sql_err, + Self::Decimal(..) | Self::Series(..) | Self::Time(..) | Self::Binary(..) => { + display_sql_err + } #[cfg(feature = "python")] - Python(..) => display_sql_err, - Struct(..) => display_sql_err, + Self::Python(..) => display_sql_err, + Self::Struct(..) => display_sql_err, } } diff --git a/src/daft-dsl/src/resolve_expr.rs b/src/daft-dsl/src/resolve_expr/mod.rs similarity index 75% rename from src/daft-dsl/src/resolve_expr.rs rename to src/daft-dsl/src/resolve_expr/mod.rs index df686f60a0..20f09759d4 100644 --- a/src/daft-dsl/src/resolve_expr.rs +++ b/src/daft-dsl/src/resolve_expr/mod.rs @@ -1,3 +1,6 @@ +#[cfg(test)] +mod tests; + use std::{ cmp::Ordering, collections::{BinaryHeap, HashMap}, @@ -410,148 +413,3 @@ pub fn check_column_name_validity(name: &str, schema: &Schema) -> DaftResult<()> Ok(()) } - -#[cfg(test)] -mod tests { - use super::*; - - fn substitute_expr_getter_sugar(expr: ExprRef, schema: &Schema) -> DaftResult { - let struct_expr_map = calculate_struct_expr_map(schema); - transform_struct_gets(expr, &struct_expr_map) - } - - #[test] - fn test_substitute_expr_getter_sugar() -> DaftResult<()> { - use crate::functions::struct_::get as struct_get; - - let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64)])?); - - assert_eq!(substitute_expr_getter_sugar(col("a"), &schema)?, col("a")); - assert!(substitute_expr_getter_sugar(col("a.b"), &schema).is_err()); - assert!(matches!( - substitute_expr_getter_sugar(col("a.b"), &schema).unwrap_err(), - DaftError::ValueError(..) - )); - - let schema = Arc::new(Schema::new(vec![Field::new( - "a", - DataType::Struct(vec![Field::new("b", DataType::Int64)]), - )])?); - - assert_eq!(substitute_expr_getter_sugar(col("a"), &schema)?, col("a")); - assert_eq!( - substitute_expr_getter_sugar(col("a.b"), &schema)?, - struct_get(col("a"), "b") - ); - assert_eq!( - substitute_expr_getter_sugar(col("a.b").alias("c"), &schema)?, - struct_get(col("a"), "b").alias("c") - ); - - let schema = Arc::new(Schema::new(vec![Field::new( - "a", - DataType::Struct(vec![Field::new( - "b", - DataType::Struct(vec![Field::new("c", DataType::Int64)]), - )]), - )])?); - - assert_eq!( - substitute_expr_getter_sugar(col("a.b"), &schema)?, - struct_get(col("a"), "b") - ); - assert_eq!( - substitute_expr_getter_sugar(col("a.b.c"), &schema)?, - struct_get(struct_get(col("a"), "b"), "c") - ); - - let schema = Arc::new(Schema::new(vec![ - Field::new( - "a", - DataType::Struct(vec![Field::new( - "b", - DataType::Struct(vec![Field::new("c", DataType::Int64)]), - )]), - ), - Field::new("a.b", DataType::Int64), - ])?); - - assert_eq!( - substitute_expr_getter_sugar(col("a.b"), &schema)?, - col("a.b") - ); - assert_eq!( - substitute_expr_getter_sugar(col("a.b.c"), &schema)?, - struct_get(struct_get(col("a"), "b"), "c") - ); - - let schema = Arc::new(Schema::new(vec![ - Field::new( - "a", - DataType::Struct(vec![Field::new("b.c", DataType::Int64)]), - ), - Field::new( - "a.b", - DataType::Struct(vec![Field::new("c", DataType::Int64)]), - ), - ])?); - - assert_eq!( - substitute_expr_getter_sugar(col("a.b.c"), &schema)?, - struct_get(col("a.b"), "c") - ); - - Ok(()) - } - - #[test] - fn test_find_wildcards() -> DaftResult<()> { - let schema = Schema::new(vec![ - Field::new( - "a", - DataType::Struct(vec![Field::new("b.*", DataType::Int64)]), - ), - Field::new("c.*", DataType::Int64), - ])?; - let struct_expr_map = calculate_struct_expr_map(&schema); - - let wildcards = find_wildcards(col("test"), &struct_expr_map); - assert!(wildcards.is_empty()); - - let wildcards = find_wildcards(col("*"), &struct_expr_map); - assert!(wildcards.len() == 1 && wildcards.first().unwrap().as_ref() == "*"); - - let wildcards = find_wildcards(col("t*"), &struct_expr_map); - assert!(wildcards.len() == 1 && wildcards.first().unwrap().as_ref() == "t*"); - - let wildcards = find_wildcards(col("a.*"), &struct_expr_map); - assert!(wildcards.len() == 1 && wildcards.first().unwrap().as_ref() == "a.*"); - - let wildcards = find_wildcards(col("c.*"), &struct_expr_map); - assert!(wildcards.is_empty()); - - let wildcards = find_wildcards(col("a.b.*"), &struct_expr_map); - assert!(wildcards.is_empty()); - - let wildcards = find_wildcards(col("a.b*"), &struct_expr_map); - assert!(wildcards.len() == 1 && wildcards.first().unwrap().as_ref() == "a.b*"); - - // nested expression - let wildcards = find_wildcards(col("t*").add(col("a.*")), &struct_expr_map); - assert!(wildcards.len() == 2); - assert!(wildcards.iter().any(|s| s.as_ref() == "t*")); - assert!(wildcards.iter().any(|s| s.as_ref() == "a.*")); - - let wildcards = find_wildcards(col("t*").add(col("a")), &struct_expr_map); - assert!(wildcards.len() == 1 && wildcards.first().unwrap().as_ref() == "t*"); - - // schema containing * - let schema = Schema::new(vec![Field::new("*", DataType::Int64)])?; - let struct_expr_map = calculate_struct_expr_map(&schema); - - let wildcards = find_wildcards(col("*"), &struct_expr_map); - assert!(wildcards.is_empty()); - - Ok(()) - } -} diff --git a/src/daft-dsl/src/resolve_expr/tests.rs b/src/daft-dsl/src/resolve_expr/tests.rs new file mode 100644 index 0000000000..dcb3147207 --- /dev/null +++ b/src/daft-dsl/src/resolve_expr/tests.rs @@ -0,0 +1,141 @@ +use super::*; + +fn substitute_expr_getter_sugar(expr: ExprRef, schema: &Schema) -> DaftResult { + let struct_expr_map = calculate_struct_expr_map(schema); + transform_struct_gets(expr, &struct_expr_map) +} + +#[test] +fn test_substitute_expr_getter_sugar() -> DaftResult<()> { + use crate::functions::struct_::get as struct_get; + + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64)])?); + + assert_eq!(substitute_expr_getter_sugar(col("a"), &schema)?, col("a")); + assert!(substitute_expr_getter_sugar(col("a.b"), &schema).is_err()); + assert!(matches!( + substitute_expr_getter_sugar(col("a.b"), &schema).unwrap_err(), + DaftError::ValueError(..) + )); + + let schema = Arc::new(Schema::new(vec![Field::new( + "a", + DataType::Struct(vec![Field::new("b", DataType::Int64)]), + )])?); + + assert_eq!(substitute_expr_getter_sugar(col("a"), &schema)?, col("a")); + assert_eq!( + substitute_expr_getter_sugar(col("a.b"), &schema)?, + struct_get(col("a"), "b") + ); + assert_eq!( + substitute_expr_getter_sugar(col("a.b").alias("c"), &schema)?, + struct_get(col("a"), "b").alias("c") + ); + + let schema = Arc::new(Schema::new(vec![Field::new( + "a", + DataType::Struct(vec![Field::new( + "b", + DataType::Struct(vec![Field::new("c", DataType::Int64)]), + )]), + )])?); + + assert_eq!( + substitute_expr_getter_sugar(col("a.b"), &schema)?, + struct_get(col("a"), "b") + ); + assert_eq!( + substitute_expr_getter_sugar(col("a.b.c"), &schema)?, + struct_get(struct_get(col("a"), "b"), "c") + ); + + let schema = Arc::new(Schema::new(vec![ + Field::new( + "a", + DataType::Struct(vec![Field::new( + "b", + DataType::Struct(vec![Field::new("c", DataType::Int64)]), + )]), + ), + Field::new("a.b", DataType::Int64), + ])?); + + assert_eq!( + substitute_expr_getter_sugar(col("a.b"), &schema)?, + col("a.b") + ); + assert_eq!( + substitute_expr_getter_sugar(col("a.b.c"), &schema)?, + struct_get(struct_get(col("a"), "b"), "c") + ); + + let schema = Arc::new(Schema::new(vec![ + Field::new( + "a", + DataType::Struct(vec![Field::new("b.c", DataType::Int64)]), + ), + Field::new( + "a.b", + DataType::Struct(vec![Field::new("c", DataType::Int64)]), + ), + ])?); + + assert_eq!( + substitute_expr_getter_sugar(col("a.b.c"), &schema)?, + struct_get(col("a.b"), "c") + ); + + Ok(()) +} + +#[test] +fn test_find_wildcards() -> DaftResult<()> { + let schema = Schema::new(vec![ + Field::new( + "a", + DataType::Struct(vec![Field::new("b.*", DataType::Int64)]), + ), + Field::new("c.*", DataType::Int64), + ])?; + let struct_expr_map = calculate_struct_expr_map(&schema); + + let wildcards = find_wildcards(col("test"), &struct_expr_map); + assert!(wildcards.is_empty()); + + let wildcards = find_wildcards(col("*"), &struct_expr_map); + assert!(wildcards.len() == 1 && wildcards.first().unwrap().as_ref() == "*"); + + let wildcards = find_wildcards(col("t*"), &struct_expr_map); + assert!(wildcards.len() == 1 && wildcards.first().unwrap().as_ref() == "t*"); + + let wildcards = find_wildcards(col("a.*"), &struct_expr_map); + assert!(wildcards.len() == 1 && wildcards.first().unwrap().as_ref() == "a.*"); + + let wildcards = find_wildcards(col("c.*"), &struct_expr_map); + assert!(wildcards.is_empty()); + + let wildcards = find_wildcards(col("a.b.*"), &struct_expr_map); + assert!(wildcards.is_empty()); + + let wildcards = find_wildcards(col("a.b*"), &struct_expr_map); + assert!(wildcards.len() == 1 && wildcards.first().unwrap().as_ref() == "a.b*"); + + // nested expression + let wildcards = find_wildcards(col("t*").add(col("a.*")), &struct_expr_map); + assert!(wildcards.len() == 2); + assert!(wildcards.iter().any(|s| s.as_ref() == "t*")); + assert!(wildcards.iter().any(|s| s.as_ref() == "a.*")); + + let wildcards = find_wildcards(col("t*").add(col("a")), &struct_expr_map); + assert!(wildcards.len() == 1 && wildcards.first().unwrap().as_ref() == "t*"); + + // schema containing * + let schema = Schema::new(vec![Field::new("*", DataType::Int64)])?; + let struct_expr_map = calculate_struct_expr_map(&schema); + + let wildcards = find_wildcards(col("*"), &struct_expr_map); + assert!(wildcards.is_empty()); + + Ok(()) +} From a53dfaa61a61b6e8c99df3dbefe8553074ca5278 Mon Sep 17 00:00:00 2001 From: Raunak Bhagat Date: Sun, 6 Oct 2024 16:35:05 -0700 Subject: [PATCH 02/28] Add all structure for stddev --- daft/daft/__init__.pyi | 2 + daft/dataframe/dataframe.py | 22 +++ daft/expressions/expressions.py | 5 + daft/series.py | 4 + src/daft-core/src/datatypes/agg_ops.rs | 2 +- src/daft-core/src/datatypes/mod.rs | 2 +- src/daft-dsl/src/expr/mod.rs | 16 ++- src/daft-dsl/src/resolve_expr/mod.rs | 51 +++---- src/daft-functions/src/list/mean.rs | 4 +- src/daft-plan/src/logical_ops/project.rs | 5 + .../src/physical_planner/translate.rs | 128 +++++++++++------- src/daft-sql/src/modules/aggs.rs | 1 + src/daft-table/src/lib.rs | 1 + 13 files changed, 160 insertions(+), 83 deletions(-) diff --git a/daft/daft/__init__.pyi b/daft/daft/__init__.pyi index c90817dfc2..6e91492cc3 100644 --- a/daft/daft/__init__.pyi +++ b/daft/daft/__init__.pyi @@ -1051,6 +1051,7 @@ class PyExpr: def approx_count_distinct(self) -> PyExpr: ... def approx_percentiles(self, percentiles: float | list[float]) -> PyExpr: ... def mean(self) -> PyExpr: ... + def stddev(self) -> PyExpr: ... def min(self) -> PyExpr: ... def max(self) -> PyExpr: ... def any_value(self, ignore_nulls: bool) -> PyExpr: ... @@ -1334,6 +1335,7 @@ class PySeries: def count(self, mode: CountMode) -> PySeries: ... def sum(self) -> PySeries: ... def mean(self) -> PySeries: ... + def stddev(self) -> PySeries: ... def min(self) -> PySeries: ... def max(self) -> PySeries: ... def agg_list(self) -> PySeries: ... diff --git a/daft/dataframe/dataframe.py b/daft/dataframe/dataframe.py index 6211423e94..114c4a598f 100644 --- a/daft/dataframe/dataframe.py +++ b/daft/dataframe/dataframe.py @@ -2118,6 +2118,17 @@ def mean(self, *cols: ColumnInputType) -> "DataFrame": """ return self._apply_agg_fn(Expression.mean, cols) + @DataframePublicAPI + def stddev(self, *cols: ColumnInputType) -> "DataFrame": + """Performs a global standard deviation on the DataFrame + + Args: + *cols (Union[str, Expression]): columns to stddev + Returns: + DataFrame: Globally aggregated standard deviation. Should be a single row. + """ + return self._apply_agg_fn(Expression.stddev, cols) + @DataframePublicAPI def min(self, *cols: ColumnInputType) -> "DataFrame": """Performs a global min on the DataFrame @@ -2856,6 +2867,17 @@ def mean(self, *cols: ColumnInputType) -> "DataFrame": """ return self.df._apply_agg_fn(Expression.mean, cols, self.group_by) + def stddev(self, *cols: ColumnInputType) -> "DataFrame": + """Performs grouped standard deviation on this GroupedDataFrame. + + Args: + *cols (Union[str, Expression]): columns to stddev + + Returns: + DataFrame: DataFrame with grouped standard deviation. + """ + return self.df._apply_agg_fn(Expression.stddev, cols, self.group_by) + def min(self, *cols: ColumnInputType) -> "DataFrame": """Perform grouped min on this GroupedDataFrame. diff --git a/daft/expressions/expressions.py b/daft/expressions/expressions.py index 2701aebc77..0fb885175e 100644 --- a/daft/expressions/expressions.py +++ b/daft/expressions/expressions.py @@ -855,6 +855,11 @@ def mean(self) -> Expression: expr = self._expr.mean() return Expression._from_pyexpr(expr) + def stddev(self) -> Expression: + """Calculates the standard deviation of the values in the expression""" + expr = self._expr.stddev() + return Expression._from_pyexpr(expr) + def min(self) -> Expression: """Calculates the minimum value in the expression""" expr = self._expr.min() diff --git a/daft/series.py b/daft/series.py index 15c5295b4c..5cbcfe7ba0 100644 --- a/daft/series.py +++ b/daft/series.py @@ -512,6 +512,10 @@ def mean(self) -> Series: assert self._series is not None return Series._from_pyseries(self._series.mean()) + def stddev(self) -> Series: + assert self._series is not None + return Series._from_pyseries(self._series.stddev()) + def sum(self) -> Series: assert self._series is not None return Series._from_pyseries(self._series.sum()) diff --git a/src/daft-core/src/datatypes/agg_ops.rs b/src/daft-core/src/datatypes/agg_ops.rs index a6420b039b..53f0f19536 100644 --- a/src/daft-core/src/datatypes/agg_ops.rs +++ b/src/daft-core/src/datatypes/agg_ops.rs @@ -23,7 +23,7 @@ pub fn try_sum_supertype(dtype: &DataType) -> DaftResult { } /// Get the data type that the mean of a column of the given data type should be casted to. -pub fn try_mean_supertype(dtype: &DataType) -> DaftResult { +pub fn try_numeric_aggregation_supertype(dtype: &DataType) -> DaftResult { if dtype.is_numeric() { Ok(DataType::Float64) } else { diff --git a/src/daft-core/src/datatypes/mod.rs b/src/daft-core/src/datatypes/mod.rs index 174098ada9..36a010a9bd 100644 --- a/src/daft-core/src/datatypes/mod.rs +++ b/src/daft-core/src/datatypes/mod.rs @@ -6,7 +6,7 @@ pub use infer_datatype::InferDataType; pub mod prelude; use std::ops::{Add, Div, Mul, Rem, Sub}; -pub use agg_ops::{try_mean_supertype, try_sum_supertype}; +pub use agg_ops::{try_numeric_aggregation_supertype, try_sum_supertype}; use arrow2::{ compute::comparison::Simd8, types::{simd::Simd, NativeType}, diff --git a/src/daft-dsl/src/expr/mod.rs b/src/daft-dsl/src/expr/mod.rs index b99813fc9d..c47456ae1d 100644 --- a/src/daft-dsl/src/expr/mod.rs +++ b/src/daft-dsl/src/expr/mod.rs @@ -10,7 +10,7 @@ use common_error::{DaftError, DaftResult}; use common_hashable_float_wrapper::FloatWrapper; use common_treenode::TreeNode; use daft_core::{ - datatypes::{try_mean_supertype, try_sum_supertype, InferDataType}, + datatypes::{try_numeric_aggregation_supertype, try_sum_supertype, InferDataType}, prelude::*, utils::supertype::try_get_supertype, }; @@ -124,6 +124,9 @@ pub enum AggExpr { #[display("mean({_0})")] Mean(ExprRef), + #[display("stddev({_0})")] + Stddev(ExprRef), + #[display("min({_0})")] Min(ExprRef), @@ -170,6 +173,7 @@ impl AggExpr { | Self::ApproxSketch(expr, _) | Self::MergeSketch(expr, _) | Self::Mean(expr) + | Self::Stddev(expr) | Self::Min(expr) | Self::Max(expr) | Self::AnyValue(expr, _) @@ -216,6 +220,10 @@ impl AggExpr { "{child_id}.local_merge_sketch(sketch_type={sketch_type:?})" )) } + Self::Stddev(expr) => { + let child_id = expr.semantic_id(schema); + FieldID::new(format!("{child_id}.local_stddev()")) + } Self::Mean(expr) => { let child_id = expr.semantic_id(schema); FieldID::new(format!("{child_id}.local_mean()")) @@ -255,6 +263,7 @@ impl AggExpr { | Self::ApproxSketch(expr, _) | Self::MergeSketch(expr, _) | Self::Mean(expr) + | Self::Stddev(expr) | Self::Min(expr) | Self::Max(expr) | Self::AnyValue(expr, _) @@ -275,6 +284,7 @@ impl AggExpr { Self::Count(_, count_mode) => Self::Count(first_child(), *count_mode), Self::Sum(_) => Self::Sum(first_child()), Self::Mean(_) => Self::Mean(first_child()), + Self::Stddev(_) => Self::Stddev(first_child()), Self::Min(_) => Self::Min(first_child()), Self::Max(_) => Self::Max(first_child()), Self::AnyValue(_, ignore_nulls) => Self::AnyValue(first_child(), *ignore_nulls), @@ -372,11 +382,11 @@ impl AggExpr { }; Ok(Field::new(field.name, dtype)) } - Self::Mean(expr) => { + Self::Mean(expr) | Self::Stddev(expr) => { let field = expr.to_field(schema)?; Ok(Field::new( field.name.as_str(), - try_mean_supertype(&field.dtype)?, + try_numeric_aggregation_supertype(&field.dtype)?, )) } Self::Min(expr) | Self::Max(expr) | Self::AnyValue(expr, _) => { diff --git a/src/daft-dsl/src/resolve_expr/mod.rs b/src/daft-dsl/src/resolve_expr/mod.rs index 20f09759d4..e2537ff03f 100644 --- a/src/daft-dsl/src/resolve_expr/mod.rs +++ b/src/daft-dsl/src/resolve_expr/mod.rs @@ -206,44 +206,47 @@ fn expand_wildcards( } fn extract_agg_expr(expr: &Expr) -> DaftResult { - use crate::Expr::*; - match expr { - Agg(agg_expr) => Ok(agg_expr.clone()), - Function { func, inputs } => Ok(AggExpr::MapGroups { + Expr::Agg(agg_expr) => Ok(agg_expr.clone()), + Expr::Function { func, inputs } => Ok(AggExpr::MapGroups { func: func.clone(), inputs: inputs.clone(), }), - Alias(e, name) => extract_agg_expr(e).map(|agg_expr| { - use crate::AggExpr::*; - + Expr::Alias(e, name) => extract_agg_expr(e).map(|agg_expr| { // reorder expressions so that alias goes before agg match agg_expr { - Count(e, count_mode) => Count(Alias(e, name.clone()).into(), count_mode), - Sum(e) => Sum(Alias(e, name.clone()).into()), - ApproxPercentile(ApproxPercentileParams { + AggExpr::Count(e, count_mode) => { + AggExpr::Count(Expr::Alias(e, name.clone()).into(), count_mode) + } + AggExpr::Sum(e) => AggExpr::Sum(Expr::Alias(e, name.clone()).into()), + AggExpr::ApproxPercentile(ApproxPercentileParams { child: e, percentiles, force_list_output, - }) => ApproxPercentile(ApproxPercentileParams { - child: Alias(e, name.clone()).into(), + }) => AggExpr::ApproxPercentile(ApproxPercentileParams { + child: Expr::Alias(e, name.clone()).into(), percentiles, force_list_output, }), - ApproxCountDistinct(e) => ApproxCountDistinct(Alias(e, name.clone()).into()), - ApproxSketch(e, sketch_type) => { - ApproxSketch(Alias(e, name.clone()).into(), sketch_type) + AggExpr::ApproxCountDistinct(e) => { + AggExpr::ApproxCountDistinct(Expr::Alias(e, name.clone()).into()) + } + AggExpr::ApproxSketch(e, sketch_type) => { + AggExpr::ApproxSketch(Expr::Alias(e, name.clone()).into(), sketch_type) + } + AggExpr::MergeSketch(e, sketch_type) => { + AggExpr::MergeSketch(Expr::Alias(e, name.clone()).into(), sketch_type) } - MergeSketch(e, sketch_type) => { - MergeSketch(Alias(e, name.clone()).into(), sketch_type) + AggExpr::Mean(e) => AggExpr::Mean(Expr::Alias(e, name.clone()).into()), + AggExpr::Stddev(e) => AggExpr::Stddev(Expr::Alias(e, name.clone()).into()), + AggExpr::Min(e) => AggExpr::Min(Expr::Alias(e, name.clone()).into()), + AggExpr::Max(e) => AggExpr::Max(Expr::Alias(e, name.clone()).into()), + AggExpr::AnyValue(e, ignore_nulls) => { + AggExpr::AnyValue(Expr::Alias(e, name.clone()).into(), ignore_nulls) } - Mean(e) => Mean(Alias(e, name.clone()).into()), - Min(e) => Min(Alias(e, name.clone()).into()), - Max(e) => Max(Alias(e, name.clone()).into()), - AnyValue(e, ignore_nulls) => AnyValue(Alias(e, name.clone()).into(), ignore_nulls), - List(e) => List(Alias(e, name.clone()).into()), - Concat(e) => Concat(Alias(e, name.clone()).into()), - MapGroups { func, inputs } => MapGroups { + AggExpr::List(e) => AggExpr::List(Expr::Alias(e, name.clone()).into()), + AggExpr::Concat(e) => AggExpr::Concat(Expr::Alias(e, name.clone()).into()), + AggExpr::MapGroups { func, inputs } => AggExpr::MapGroups { func, inputs: inputs .into_iter() diff --git a/src/daft-functions/src/list/mean.rs b/src/daft-functions/src/list/mean.rs index 16a817a9c3..1396d304d6 100644 --- a/src/daft-functions/src/list/mean.rs +++ b/src/daft-functions/src/list/mean.rs @@ -1,6 +1,6 @@ use common_error::{DaftError, DaftResult}; use daft_core::{ - datatypes::try_mean_supertype, + datatypes::try_numeric_aggregation_supertype, prelude::{Field, Schema}, series::Series, }; @@ -29,7 +29,7 @@ impl ScalarUDF for ListMean { let inner_field = input.to_field(schema)?.to_exploded_field()?; Ok(Field::new( inner_field.name.as_str(), - try_mean_supertype(&inner_field.dtype)?, + try_numeric_aggregation_supertype(&inner_field.dtype)?, )) } _ => Err(DaftError::SchemaMismatch(format!( diff --git a/src/daft-plan/src/logical_ops/project.rs b/src/daft-plan/src/logical_ops/project.rs index 41101fcd17..80cd923dd9 100644 --- a/src/daft-plan/src/logical_ops/project.rs +++ b/src/daft-plan/src/logical_ops/project.rs @@ -411,6 +411,11 @@ fn replace_column_with_semantic_id_aggexpr( |_| e.clone(), ) } + AggExpr::Stddev(ref _child) => { + todo!("stddev") + // replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema) + // .map_yes_no(AggExpr::Mean, |_| e.clone()) + } AggExpr::Mean(ref child) => { replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema) .map_yes_no(AggExpr::Mean, |_| e.clone()) diff --git a/src/daft-plan/src/physical_planner/translate.rs b/src/daft-plan/src/physical_planner/translate.rs index 639c571871..f3ee4eaeed 100644 --- a/src/daft-plan/src/physical_planner/translate.rs +++ b/src/daft-plan/src/physical_planner/translate.rs @@ -8,7 +8,9 @@ use common_daft_config::DaftExecutionConfig; use common_error::DaftResult; use common_file_formats::FileFormat; use daft_core::prelude::*; -use daft_dsl::{col, is_partition_compatible, ApproxPercentileParams, ExprRef, SketchType}; +use daft_dsl::{ + col, is_partition_compatible, AggExpr, ApproxPercentileParams, ExprRef, SketchType, +}; use daft_scan::PhysicalScanInfo; use crate::{ @@ -765,8 +767,6 @@ pub fn populate_aggregation_stages( HashMap, daft_dsl::AggExpr>, Vec, ) { - use daft_dsl::AggExpr::{self, *}; - // Aggregations to apply in the first and second stages. // Semantic column name -> AggExpr let mut first_stage_aggs: HashMap, AggExpr> = HashMap::new(); @@ -777,144 +777,168 @@ pub fn populate_aggregation_stages( for agg_expr in aggregations { let output_name = agg_expr.name(); match agg_expr { - Count(e, mode) => { + AggExpr::Count(e, mode) => { let count_id = agg_expr.semantic_id(schema).id; - let sum_of_count_id = Sum(col(count_id.clone())).semantic_id(schema).id; + let sum_of_count_id = AggExpr::Sum(col(count_id.clone())).semantic_id(schema).id; first_stage_aggs .entry(count_id.clone()) - .or_insert(Count(e.alias(count_id.clone()).clone(), *mode)); + .or_insert(AggExpr::Count(e.alias(count_id.clone()).clone(), *mode)); second_stage_aggs .entry(sum_of_count_id.clone()) - .or_insert(Sum(col(count_id.clone()).alias(sum_of_count_id.clone()))); + .or_insert(AggExpr::Sum( + col(count_id.clone()).alias(sum_of_count_id.clone()), + )); final_exprs.push(col(sum_of_count_id.clone()).alias(output_name)); } - Sum(e) => { + AggExpr::Sum(e) => { let sum_id = agg_expr.semantic_id(schema).id; - let sum_of_sum_id = Sum(col(sum_id.clone())).semantic_id(schema).id; + let sum_of_sum_id = AggExpr::Sum(col(sum_id.clone())).semantic_id(schema).id; first_stage_aggs .entry(sum_id.clone()) - .or_insert(Sum(e.alias(sum_id.clone()).clone())); + .or_insert(AggExpr::Sum(e.alias(sum_id.clone()).clone())); second_stage_aggs .entry(sum_of_sum_id.clone()) - .or_insert(Sum(col(sum_id.clone()).alias(sum_of_sum_id.clone()))); + .or_insert(AggExpr::Sum( + col(sum_id.clone()).alias(sum_of_sum_id.clone()), + )); final_exprs.push(col(sum_of_sum_id.clone()).alias(output_name)); } - Mean(e) => { - let sum_id = Sum(e.clone()).semantic_id(schema).id; - let count_id = Count(e.clone(), CountMode::Valid).semantic_id(schema).id; - let sum_of_sum_id = Sum(col(sum_id.clone())).semantic_id(schema).id; - let sum_of_count_id = Sum(col(count_id.clone())).semantic_id(schema).id; + AggExpr::Mean(e) => { + let sum_id = AggExpr::Sum(e.clone()).semantic_id(schema).id; + let count_id = AggExpr::Count(e.clone(), CountMode::Valid) + .semantic_id(schema) + .id; + let sum_of_sum_id = AggExpr::Sum(col(sum_id.clone())).semantic_id(schema).id; + let sum_of_count_id = AggExpr::Sum(col(count_id.clone())).semantic_id(schema).id; first_stage_aggs .entry(sum_id.clone()) - .or_insert(Sum(e.alias(sum_id.clone()).clone())); + .or_insert(AggExpr::Sum(e.alias(sum_id.clone()).clone())); first_stage_aggs .entry(count_id.clone()) - .or_insert(Count(e.alias(count_id.clone()).clone(), CountMode::Valid)); + .or_insert(AggExpr::Count( + e.alias(count_id.clone()).clone(), + CountMode::Valid, + )); second_stage_aggs .entry(sum_of_sum_id.clone()) - .or_insert(Sum(col(sum_id.clone()).alias(sum_of_sum_id.clone()))); + .or_insert(AggExpr::Sum( + col(sum_id.clone()).alias(sum_of_sum_id.clone()), + )); second_stage_aggs .entry(sum_of_count_id.clone()) - .or_insert(Sum(col(count_id.clone()).alias(sum_of_count_id.clone()))); + .or_insert(AggExpr::Sum( + col(count_id.clone()).alias(sum_of_count_id.clone()), + )); final_exprs.push( (col(sum_of_sum_id.clone()).div(col(sum_of_count_id.clone()))) .alias(output_name), ); } - Min(e) => { + AggExpr::Stddev(_expr) => todo!("stddev"), + AggExpr::Min(e) => { let min_id = agg_expr.semantic_id(schema).id; - let min_of_min_id = Min(col(min_id.clone())).semantic_id(schema).id; + let min_of_min_id = AggExpr::Min(col(min_id.clone())).semantic_id(schema).id; first_stage_aggs .entry(min_id.clone()) - .or_insert(Min(e.alias(min_id.clone()).clone())); + .or_insert(AggExpr::Min(e.alias(min_id.clone()).clone())); second_stage_aggs .entry(min_of_min_id.clone()) - .or_insert(Min(col(min_id.clone()).alias(min_of_min_id.clone()))); + .or_insert(AggExpr::Min( + col(min_id.clone()).alias(min_of_min_id.clone()), + )); final_exprs.push(col(min_of_min_id.clone()).alias(output_name)); } - Max(e) => { + AggExpr::Max(e) => { let max_id = agg_expr.semantic_id(schema).id; - let max_of_max_id = Max(col(max_id.clone())).semantic_id(schema).id; + let max_of_max_id = AggExpr::Max(col(max_id.clone())).semantic_id(schema).id; first_stage_aggs .entry(max_id.clone()) - .or_insert(Max(e.alias(max_id.clone()).clone())); + .or_insert(AggExpr::Max(e.alias(max_id.clone()).clone())); second_stage_aggs .entry(max_of_max_id.clone()) - .or_insert(Max(col(max_id.clone()).alias(max_of_max_id.clone()))); + .or_insert(AggExpr::Max( + col(max_id.clone()).alias(max_of_max_id.clone()), + )); final_exprs.push(col(max_of_max_id.clone()).alias(output_name)); } - AnyValue(e, ignore_nulls) => { + AggExpr::AnyValue(e, ignore_nulls) => { let any_id = agg_expr.semantic_id(schema).id; - let any_of_any_id = AnyValue(col(any_id.clone()), *ignore_nulls) + let any_of_any_id = AggExpr::AnyValue(col(any_id.clone()), *ignore_nulls) .semantic_id(schema) .id; first_stage_aggs .entry(any_id.clone()) - .or_insert(AnyValue(e.alias(any_id.clone()).clone(), *ignore_nulls)); + .or_insert(AggExpr::AnyValue( + e.alias(any_id.clone()).clone(), + *ignore_nulls, + )); second_stage_aggs .entry(any_of_any_id.clone()) - .or_insert(AnyValue( + .or_insert(AggExpr::AnyValue( col(any_id.clone()).alias(any_of_any_id.clone()), *ignore_nulls, )); final_exprs.push(col(any_of_any_id.clone()).alias(output_name)); } - List(e) => { + AggExpr::List(e) => { let list_id = agg_expr.semantic_id(schema).id; - let concat_of_list_id = Concat(col(list_id.clone())).semantic_id(schema).id; + let concat_of_list_id = + AggExpr::Concat(col(list_id.clone())).semantic_id(schema).id; first_stage_aggs .entry(list_id.clone()) - .or_insert(List(e.alias(list_id.clone()).clone())); + .or_insert(AggExpr::List(e.alias(list_id.clone()).clone())); second_stage_aggs .entry(concat_of_list_id.clone()) - .or_insert(Concat( + .or_insert(AggExpr::Concat( col(list_id.clone()).alias(concat_of_list_id.clone()), )); final_exprs.push(col(concat_of_list_id.clone()).alias(output_name)); } - Concat(e) => { + AggExpr::Concat(e) => { let concat_id = agg_expr.semantic_id(schema).id; - let concat_of_concat_id = Concat(col(concat_id.clone())).semantic_id(schema).id; + let concat_of_concat_id = AggExpr::Concat(col(concat_id.clone())) + .semantic_id(schema) + .id; first_stage_aggs .entry(concat_id.clone()) - .or_insert(Concat(e.alias(concat_id.clone()).clone())); + .or_insert(AggExpr::Concat(e.alias(concat_id.clone()).clone())); second_stage_aggs .entry(concat_of_concat_id.clone()) - .or_insert(Concat( + .or_insert(AggExpr::Concat( col(concat_id.clone()).alias(concat_of_concat_id.clone()), )); final_exprs.push(col(concat_of_concat_id.clone()).alias(output_name)); } - MapGroups { func, inputs } => { + AggExpr::MapGroups { func, inputs } => { let func_id = agg_expr.semantic_id(schema).id; // No first stage aggregation for MapGroups, do all the work in the second stage. second_stage_aggs .entry(func_id.clone()) - .or_insert(MapGroups { + .or_insert(AggExpr::MapGroups { func: func.clone(), inputs: inputs.to_vec(), }); final_exprs.push(col(output_name)); } - &ApproxPercentile(ApproxPercentileParams { + &AggExpr::ApproxPercentile(ApproxPercentileParams { child: ref e, ref percentiles, force_list_output, }) => { let percentiles = percentiles.iter().map(|p| p.0).collect::>(); let sketch_id = agg_expr.semantic_id(schema).id; - let approx_id = ApproxSketch(col(sketch_id.clone()), SketchType::DDSketch) + let approx_id = AggExpr::ApproxSketch(col(sketch_id.clone()), SketchType::DDSketch) .semantic_id(schema) .id; first_stage_aggs .entry(sketch_id.clone()) - .or_insert(ApproxSketch( + .or_insert(AggExpr::ApproxSketch( e.alias(sketch_id.clone()), SketchType::DDSketch, )); second_stage_aggs .entry(approx_id.clone()) - .or_insert(MergeSketch( + .or_insert(AggExpr::MergeSketch( col(sketch_id.clone()).alias(approx_id.clone()), SketchType::DDSketch, )); @@ -924,30 +948,30 @@ pub fn populate_aggregation_stages( .alias(output_name), ); } - ApproxCountDistinct(e) => { + AggExpr::ApproxCountDistinct(e) => { let first_stage_id = agg_expr.semantic_id(schema).id; let second_stage_id = - MergeSketch(col(first_stage_id.clone()), SketchType::HyperLogLog) + AggExpr::MergeSketch(col(first_stage_id.clone()), SketchType::HyperLogLog) .semantic_id(schema) .id; first_stage_aggs .entry(first_stage_id.clone()) - .or_insert(ApproxSketch( + .or_insert(AggExpr::ApproxSketch( e.alias(first_stage_id.clone()), SketchType::HyperLogLog, )); second_stage_aggs .entry(second_stage_id.clone()) - .or_insert(MergeSketch( + .or_insert(AggExpr::MergeSketch( col(first_stage_id).alias(second_stage_id.clone()), SketchType::HyperLogLog, )); final_exprs.push(col(second_stage_id).alias(output_name)); } - ApproxSketch(..) => { + AggExpr::ApproxSketch(..) => { unimplemented!("User-facing approx_sketch aggregation is not implemented") } - MergeSketch(..) => { + AggExpr::MergeSketch(..) => { unimplemented!("User-facing merge_sketch aggregation is not implemented") } } diff --git a/src/daft-sql/src/modules/aggs.rs b/src/daft-sql/src/modules/aggs.rs index 0fbd2f7067..f3bbe81a30 100644 --- a/src/daft-sql/src/modules/aggs.rs +++ b/src/daft-sql/src/modules/aggs.rs @@ -109,6 +109,7 @@ pub(crate) fn to_expr(expr: &AggExpr, args: &[ExprRef]) -> SQLPlannerResult todo!(), AggExpr::Min(_) => { ensure!(args.len() == 1, "min takes exactly one argument"); Ok(args[0].clone().min()) diff --git a/src/daft-table/src/lib.rs b/src/daft-table/src/lib.rs index 3669fda3f5..148b02f32f 100644 --- a/src/daft-table/src/lib.rs +++ b/src/daft-table/src/lib.rs @@ -480,6 +480,7 @@ impl Table { } } AggExpr::Mean(expr) => self.eval_expression(expr)?.mean(groups), + AggExpr::Stddev(_expr) => todo!("stddev"), AggExpr::Min(expr) => self.eval_expression(expr)?.min(groups), AggExpr::Max(expr) => self.eval_expression(expr)?.max(groups), &AggExpr::AnyValue(ref expr, ignore_nulls) => { From 2ae9d0844ee146fae0ae4ef0b76f9e2cf063d921 Mon Sep 17 00:00:00 2001 From: Raunak Bhagat Date: Sun, 6 Oct 2024 17:12:11 -0700 Subject: [PATCH 03/28] Implement structure for local and distributed stddev --- src/daft-core/src/array/ops/mod.rs | 7 +++++ src/daft-core/src/array/ops/stddev.rs | 16 ++++++++++ src/daft-core/src/series/ops/agg.rs | 37 ++++++++++++------------ src/daft-plan/src/logical_ops/project.rs | 31 ++++++++++---------- src/daft-schema/src/dtype.rs | 12 ++++++++ src/daft-table/src/lib.rs | 2 +- 6 files changed, 70 insertions(+), 35 deletions(-) create mode 100644 src/daft-core/src/array/ops/stddev.rs diff --git a/src/daft-core/src/array/ops/mod.rs b/src/daft-core/src/array/ops/mod.rs index d3a940f376..3bcf0f0cb9 100644 --- a/src/daft-core/src/array/ops/mod.rs +++ b/src/daft-core/src/array/ops/mod.rs @@ -49,6 +49,7 @@ mod sketch_percentile; mod sort; pub(crate) mod sparse_tensor; mod sqrt; +mod stddev; mod struct_; mod sum; mod take; @@ -189,6 +190,12 @@ pub trait DaftMeanAggable { fn grouped_mean(&self, groups: &GroupIndices) -> Self::Output; } +pub trait DaftStddevAggable { + type Output; + fn stddev(&self) -> Self::Output; + fn grouped_stddev(&self, groups: &GroupIndices) -> Self::Output; +} + pub trait DaftCompareAggable { type Output; fn min(&self) -> Self::Output; diff --git a/src/daft-core/src/array/ops/stddev.rs b/src/daft-core/src/array/ops/stddev.rs new file mode 100644 index 0000000000..cf73fe6ff5 --- /dev/null +++ b/src/daft-core/src/array/ops/stddev.rs @@ -0,0 +1,16 @@ +use common_error::DaftResult; + +use super::{DaftStddevAggable, GroupIndices}; +use crate::{array::DataArray, datatypes::Float64Type}; + +impl DaftStddevAggable for DataArray { + type Output = DaftResult; + + fn stddev(&self) -> Self::Output { + todo!("stddev") + } + + fn grouped_stddev(&self, _: &GroupIndices) -> Self::Output { + todo!("stddev") + } +} diff --git a/src/daft-core/src/series/ops/agg.rs b/src/daft-core/src/series/ops/agg.rs index 541fe5c556..79cfaa484a 100644 --- a/src/daft-core/src/series/ops/agg.rs +++ b/src/daft-core/src/series/ops/agg.rs @@ -4,7 +4,7 @@ use logical::Decimal128Array; use crate::{ array::{ - ops::{DaftHllMergeAggable, GroupIndices}, + ops::{DaftHllMergeAggable, DaftMeanAggable, DaftStddevAggable, GroupIndices}, ListArray, }, count_mode::CountMode, @@ -149,24 +149,25 @@ impl Series { } pub fn mean(&self, groups: Option<&GroupIndices>) -> DaftResult { - use crate::{array::ops::DaftMeanAggable, datatypes::DataType::*}; - // Upcast all numeric types to float64 and use f64 mean kernel. - match self.data_type() { - dt if dt.is_numeric() => { - let casted = self.cast(&Float64)?; - match groups { - Some(groups) => { - Ok(DaftMeanAggable::grouped_mean(&casted.f64()?, groups)?.into_series()) - } - None => Ok(DaftMeanAggable::mean(&casted.f64()?)?.into_series()), - } - } - other => Err(DaftError::TypeError(format!( - "Numeric mean is not implemented for type {}", - other - ))), - } + self.data_type().assert_is_numeric()?; + let casted = self.cast(&DataType::Float64)?; + let casted = casted.f64()?; + let series = groups + .map_or_else(|| casted.mean(), |groups| casted.grouped_mean(groups))? + .into_series(); + Ok(series) + } + + pub fn stddev(&self, groups: Option<&GroupIndices>) -> DaftResult { + // Upcast all numeric types to float64 and use f64 stddev kernel. + self.data_type().assert_is_numeric()?; + let casted = self.cast(&DataType::Float64)?; + let casted = casted.f64()?; + let series = groups + .map_or_else(|| casted.stddev(), |groups| casted.grouped_stddev(groups))? + .into_series(); + Ok(series) } pub fn min(&self, groups: Option<&GroupIndices>) -> DaftResult { diff --git a/src/daft-plan/src/logical_ops/project.rs b/src/daft-plan/src/logical_ops/project.rs index 80cd923dd9..e3e956cdaf 100644 --- a/src/daft-plan/src/logical_ops/project.rs +++ b/src/daft-plan/src/logical_ops/project.rs @@ -373,24 +373,24 @@ fn replace_column_with_semantic_id_aggexpr( AggExpr::Count(ref child, mode) => { replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema).map_yes_no( |transformed_child| AggExpr::Count(transformed_child, mode), - |_| e.clone(), + |_| e, ) } AggExpr::Sum(ref child) => { replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema) - .map_yes_no(AggExpr::Sum, |_| e.clone()) + .map_yes_no(AggExpr::Sum, |_| e) } AggExpr::ApproxPercentile(ApproxPercentileParams { ref child, ref percentiles, - ref force_list_output, + force_list_output, }) => replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema) .map_yes_no( |transformed_child| { AggExpr::ApproxPercentile(ApproxPercentileParams { child: transformed_child, percentiles: percentiles.clone(), - force_list_output: *force_list_output, + force_list_output, }) }, |_| e.clone(), @@ -402,45 +402,44 @@ fn replace_column_with_semantic_id_aggexpr( AggExpr::ApproxSketch(ref child, sketch_type) => { replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema).map_yes_no( |transformed_child| AggExpr::ApproxSketch(transformed_child, sketch_type), - |_| e.clone(), + |_| e, ) } AggExpr::MergeSketch(ref child, sketch_type) => { replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema).map_yes_no( |transformed_child| AggExpr::MergeSketch(transformed_child, sketch_type), - |_| e.clone(), + |_| e, ) } - AggExpr::Stddev(ref _child) => { - todo!("stddev") - // replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema) - // .map_yes_no(AggExpr::Mean, |_| e.clone()) + AggExpr::Stddev(ref child) => { + replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema) + .map_yes_no(AggExpr::Stddev, |_| e) } AggExpr::Mean(ref child) => { replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema) - .map_yes_no(AggExpr::Mean, |_| e.clone()) + .map_yes_no(AggExpr::Mean, |_| e) } AggExpr::Min(ref child) => { replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema) - .map_yes_no(AggExpr::Min, |_| e.clone()) + .map_yes_no(AggExpr::Min, |_| e) } AggExpr::Max(ref child) => { replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema) - .map_yes_no(AggExpr::Max, |_| e.clone()) + .map_yes_no(AggExpr::Max, |_| e) } AggExpr::AnyValue(ref child, ignore_nulls) => { replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema).map_yes_no( |transformed_child| AggExpr::AnyValue(transformed_child, ignore_nulls), - |_| e.clone(), + |_| e, ) } AggExpr::List(ref child) => { replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema) - .map_yes_no(AggExpr::List, |_| e.clone()) + .map_yes_no(AggExpr::List, |_| e) } AggExpr::Concat(ref child) => { replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema) - .map_yes_no(AggExpr::Concat, |_| e.clone()) + .map_yes_no(AggExpr::Concat, |_| e) } AggExpr::MapGroups { func, inputs } => { let transforms = inputs diff --git a/src/daft-schema/src/dtype.rs b/src/daft-schema/src/dtype.rs index 65cf8f808e..2461aa6287 100644 --- a/src/daft-schema/src/dtype.rs +++ b/src/daft-schema/src/dtype.rs @@ -367,6 +367,18 @@ impl DataType { } } + #[inline] + pub fn assert_is_numeric(&self) -> DaftResult<()> { + if self.is_numeric() { + Ok(()) + } else { + Err(DaftError::TypeError(format!( + "Numeric mean is not implemented for type {}", + self, + ))) + } + } + #[inline] pub fn is_fixed_size_numeric(&self) -> bool { match self { diff --git a/src/daft-table/src/lib.rs b/src/daft-table/src/lib.rs index 148b02f32f..eff28c6a26 100644 --- a/src/daft-table/src/lib.rs +++ b/src/daft-table/src/lib.rs @@ -480,7 +480,7 @@ impl Table { } } AggExpr::Mean(expr) => self.eval_expression(expr)?.mean(groups), - AggExpr::Stddev(_expr) => todo!("stddev"), + AggExpr::Stddev(expr) => self.eval_expression(expr)?.stddev(groups), AggExpr::Min(expr) => self.eval_expression(expr)?.min(groups), AggExpr::Max(expr) => self.eval_expression(expr)?.max(groups), &AggExpr::AnyValue(ref expr, ignore_nulls) => { From c7f189c26f77b730cbbac96aa1e6c37f01cf8c1c Mon Sep 17 00:00:00 2001 From: Raunak Bhagat Date: Mon, 7 Oct 2024 08:29:32 -0700 Subject: [PATCH 04/28] Implement non-grouped stddev - factored out some common logic into a util::stats module - refactored mean to use the new module --- src/daft-core/src/array/ops/mean.rs | 28 ++++++++++++--------------- src/daft-core/src/array/ops/stddev.rs | 27 +++++++++++++++++++++++--- src/daft-core/src/utils/mod.rs | 1 + src/daft-core/src/utils/stats.rs | 25 ++++++++++++++++++++++++ 4 files changed, 62 insertions(+), 19 deletions(-) create mode 100644 src/daft-core/src/utils/stats.rs diff --git a/src/daft-core/src/array/ops/mean.rs b/src/daft-core/src/array/ops/mean.rs index b4b4016bbc..83896db697 100644 --- a/src/daft-core/src/array/ops/mean.rs +++ b/src/daft-core/src/array/ops/mean.rs @@ -1,28 +1,24 @@ use std::sync::Arc; +use arrow2::array::PrimitiveArray; use common_error::DaftResult; use super::{as_arrow::AsArrow, DaftCountAggable, DaftMeanAggable, DaftSumAggable}; -use crate::{array::ops::GroupIndices, count_mode::CountMode, datatypes::*}; +use crate::{ + array::ops::GroupIndices, + count_mode::CountMode, + datatypes::*, + utils::stats::{stats, Stats}, +}; impl DaftMeanAggable for &DataArray { type Output = DaftResult>; fn mean(&self) -> Self::Output { - let sum_value = DaftSumAggable::sum(self)?.as_arrow().value(0); - let count_value = DaftCountAggable::count(self, CountMode::Valid)? - .as_arrow() - .value(0); - - let result = match count_value { - 0 => None, - count_value => Some(sum_value / count_value as f64), - }; - let arrow_array = Box::new(arrow2::array::PrimitiveArray::from([result])); - - DataArray::new( - Arc::new(Field::new(self.field.name.clone(), DataType::Float64)), - arrow_array, - ) + let Stats { mean, count, .. } = stats(self)?; + let value = mean.map(|mean| mean / count as f64); + let data = PrimitiveArray::from([value]).boxed(); + let field = Arc::new(Field::new(self.field.name.clone(), DataType::Float64)); + DataArray::new(field, data) } fn grouped_mean(&self, groups: &GroupIndices) -> Self::Output { diff --git a/src/daft-core/src/array/ops/stddev.rs b/src/daft-core/src/array/ops/stddev.rs index cf73fe6ff5..525136ae31 100644 --- a/src/daft-core/src/array/ops/stddev.rs +++ b/src/daft-core/src/array/ops/stddev.rs @@ -1,13 +1,34 @@ +use std::sync::Arc; + +use arrow2::array::PrimitiveArray; use common_error::DaftResult; +use daft_schema::{dtype::DataType, field::Field}; -use super::{DaftStddevAggable, GroupIndices}; -use crate::{array::DataArray, datatypes::Float64Type}; +use crate::{ + array::{ + ops::{DaftStddevAggable, GroupIndices}, + DataArray, + }, + datatypes::Float64Type, + utils::stats::{stats, Stats}, +}; impl DaftStddevAggable for DataArray { type Output = DaftResult; fn stddev(&self) -> Self::Output { - todo!("stddev") + let Stats { count, mean, .. } = stats(self)?; + let stddev = mean.map(|mean| { + let mut square_sum = 0.0; + for &value in self.into_iter().flatten() { + square_sum += (value - mean).powi(2); + } + let variance = square_sum / count as f64; + variance.sqrt() + }); + let field = Arc::new(Field::new(self.field.name.clone(), DataType::Float64)); + let data = PrimitiveArray::::from([stddev]).boxed(); + Self::new(field, data) } fn grouped_stddev(&self, _: &GroupIndices) -> Self::Output { diff --git a/src/daft-core/src/utils/mod.rs b/src/daft-core/src/utils/mod.rs index 2e039e6953..baf1dc66fd 100644 --- a/src/daft-core/src/utils/mod.rs +++ b/src/daft-core/src/utils/mod.rs @@ -2,4 +2,5 @@ pub mod arrow; pub mod display; pub mod dyn_compare; pub mod identity_hash_set; +pub mod stats; pub mod supertype; diff --git a/src/daft-core/src/utils/stats.rs b/src/daft-core/src/utils/stats.rs new file mode 100644 index 0000000000..e000d6b605 --- /dev/null +++ b/src/daft-core/src/utils/stats.rs @@ -0,0 +1,25 @@ +use common_error::DaftResult; + +use crate::{ + array::{ + ops::{DaftCountAggable, DaftSumAggable}, + prelude::Float64Array, + }, + count_mode::CountMode, +}; + +pub struct Stats { + pub sum: f64, + pub count: u64, + pub mean: Option, +} + +pub fn stats(array: &Float64Array) -> DaftResult { + let sum = array.sum()?.get(0).unwrap(); + let count = array.count(CountMode::Valid)?.get(0).unwrap(); + let mean = match count { + 0 => None, + _ => Some(sum / count as f64), + }; + Ok(Stats { sum, count, mean }) +} From 797358f52481096258dfa0d251a9bb33fc922ea6 Mon Sep 17 00:00:00 2001 From: Raunak Bhagat Date: Mon, 7 Oct 2024 09:26:19 -0700 Subject: [PATCH 05/28] Implement grouped standard deviation - some additional refactors to `utils::stats` module --- src/daft-core/src/array/ops/mean.rs | 23 +++++----- src/daft-core/src/array/ops/stddev.rs | 24 +++++----- src/daft-core/src/utils/stats.rs | 63 ++++++++++++++++++++++++--- 3 files changed, 81 insertions(+), 29 deletions(-) diff --git a/src/daft-core/src/array/ops/mean.rs b/src/daft-core/src/array/ops/mean.rs index 83896db697..1bc659279a 100644 --- a/src/daft-core/src/array/ops/mean.rs +++ b/src/daft-core/src/array/ops/mean.rs @@ -3,26 +3,27 @@ use std::sync::Arc; use arrow2::array::PrimitiveArray; use common_error::DaftResult; -use super::{as_arrow::AsArrow, DaftCountAggable, DaftMeanAggable, DaftSumAggable}; use crate::{ - array::ops::GroupIndices, + array::ops::{ + as_arrow::AsArrow, DaftCountAggable, DaftMeanAggable, DaftSumAggable, GroupIndices, + }, count_mode::CountMode, datatypes::*, - utils::stats::{stats, Stats}, + utils::stats, }; -impl DaftMeanAggable for &DataArray { - type Output = DaftResult>; + +impl DaftMeanAggable for DataArray { + type Output = DaftResult; fn mean(&self) -> Self::Output { - let Stats { mean, count, .. } = stats(self)?; - let value = mean.map(|mean| mean / count as f64); - let data = PrimitiveArray::from([value]).boxed(); + let stats = stats::calculate_stats(self)?; + let mean = stats::calculate_mean(stats.sum, stats.count); + let data = PrimitiveArray::from([mean]).boxed(); let field = Arc::new(Field::new(self.field.name.clone(), DataType::Float64)); - DataArray::new(field, data) + Self::new(field, data) } fn grouped_mean(&self, groups: &GroupIndices) -> Self::Output { - use arrow2::array::PrimitiveArray; let sum_values = self.grouped_sum(groups)?; let count_values = self.grouped_count(groups, CountMode::Valid)?; assert_eq!(sum_values.len(), count_values.len()); @@ -35,6 +36,6 @@ impl DaftMeanAggable for &DataArray { (s, c) => Some(s / (*c as f64)), }); let mean_array = Box::new(PrimitiveArray::from_trusted_len_iter(mean_per_group)); - Ok(DataArray::from((self.field.name.as_ref(), mean_array))) + Ok(Self::from((self.field.name.as_ref(), mean_array))) } } diff --git a/src/daft-core/src/array/ops/stddev.rs b/src/daft-core/src/array/ops/stddev.rs index 525136ae31..42692b62ba 100644 --- a/src/daft-core/src/array/ops/stddev.rs +++ b/src/daft-core/src/array/ops/stddev.rs @@ -10,28 +10,28 @@ use crate::{ DataArray, }, datatypes::Float64Type, - utils::stats::{stats, Stats}, + utils::stats, }; impl DaftStddevAggable for DataArray { type Output = DaftResult; fn stddev(&self) -> Self::Output { - let Stats { count, mean, .. } = stats(self)?; - let stddev = mean.map(|mean| { - let mut square_sum = 0.0; - for &value in self.into_iter().flatten() { - square_sum += (value - mean).powi(2); - } - let variance = square_sum / count as f64; - variance.sqrt() - }); + let stats = stats::calculate_stats(self)?; + let values = self.into_iter().flatten().copied(); + let stddev = stats::calculate_stddev(stats, values); let field = Arc::new(Field::new(self.field.name.clone(), DataType::Float64)); let data = PrimitiveArray::::from([stddev]).boxed(); Self::new(field, data) } - fn grouped_stddev(&self, _: &GroupIndices) -> Self::Output { - todo!("stddev") + fn grouped_stddev(&self, groups: &GroupIndices) -> Self::Output { + let grouped_stddevs_iter = stats::grouped_stats(self, groups)?.map(|(stats, group)| { + let values = group.iter().filter_map(|&index| self.get(index as _)); + stats::calculate_stddev(stats, values) + }); + let field = Arc::new(Field::new(self.field.name.clone(), DataType::Float64)); + let data = PrimitiveArray::::from_iter(grouped_stddevs_iter).boxed(); + Self::new(field, data) } } diff --git a/src/daft-core/src/utils/stats.rs b/src/daft-core/src/utils/stats.rs index e000d6b605..41aa2e5e26 100644 --- a/src/daft-core/src/utils/stats.rs +++ b/src/daft-core/src/utils/stats.rs @@ -2,8 +2,8 @@ use common_error::DaftResult; use crate::{ array::{ - ops::{DaftCountAggable, DaftSumAggable}, - prelude::Float64Array, + ops::{DaftCountAggable, DaftSumAggable, GroupIndices, VecIndices}, + prelude::{Float64Array, UInt64Array}, }, count_mode::CountMode, }; @@ -14,12 +14,63 @@ pub struct Stats { pub mean: Option, } -pub fn stats(array: &Float64Array) -> DaftResult { +pub fn calculate_stats(array: &Float64Array) -> DaftResult { let sum = array.sum()?.get(0).unwrap(); let count = array.count(CountMode::Valid)?.get(0).unwrap(); - let mean = match count { + let mean = calculate_mean(sum, count); + Ok(Stats { sum, count, mean }) +} + +pub fn grouped_stats<'a>( + array: &Float64Array, + groups: &'a GroupIndices, +) -> DaftResult> { + let grouped_sum = array.grouped_sum(groups)?; + let grouped_count = array.grouped_count(groups, CountMode::Valid)?; + assert_eq!(grouped_sum.len(), grouped_count.len()); + assert_eq!(grouped_sum.len(), groups.len()); + Ok(GroupedStats { + grouped_sum, + grouped_count, + groups: groups.iter().enumerate(), + }) +} + +struct GroupedStats<'a, I: Iterator> { + grouped_sum: Float64Array, + grouped_count: UInt64Array, + groups: I, +} + +impl<'a, I: Iterator> Iterator for GroupedStats<'a, I> { + type Item = (Stats, &'a VecIndices); + + fn next(&mut self) -> Option { + let (index, group) = self.groups.next()?; + let sum = self + .grouped_sum + .get(index) + .expect("All values in `self.grouped_sum` must be valid"); + let count = self + .grouped_count + .get(index) + .expect("All values in `self.grouped_count` must be valid"); + let mean = calculate_mean(sum, count); + let stats = Stats { sum, count, mean }; + Some((stats, group)) + } +} + +pub fn calculate_mean(sum: f64, count: u64) -> Option { + match count { 0 => None, _ => Some(sum / count as f64), - }; - Ok(Stats { sum, count, mean }) + } +} + +pub fn calculate_stddev(stats: Stats, values: impl Iterator) -> Option { + stats.mean.map(|mean| { + let sum_of_squares = values.map(|value| (value - mean).powi(2)).sum::(); + (sum_of_squares / stats.count as f64).sqrt() + }) } From 93b125ba216078debdfc0024abad6fda63234f06 Mon Sep 17 00:00:00 2001 From: Raunak Bhagat Date: Mon, 7 Oct 2024 11:13:12 -0700 Subject: [PATCH 06/28] Remove unwraps that may have panicked because of invalid first element - summing or counting may result in a `None` first element --- src/daft-core/src/utils/stats.rs | 34 +++++++++++++++++++------------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/src/daft-core/src/utils/stats.rs b/src/daft-core/src/utils/stats.rs index 41aa2e5e26..bf429a4d14 100644 --- a/src/daft-core/src/utils/stats.rs +++ b/src/daft-core/src/utils/stats.rs @@ -8,6 +8,7 @@ use crate::{ count_mode::CountMode, }; +#[derive(Clone, Copy, Default, Debug)] pub struct Stats { pub sum: f64, pub count: u64, @@ -15,10 +16,16 @@ pub struct Stats { } pub fn calculate_stats(array: &Float64Array) -> DaftResult { - let sum = array.sum()?.get(0).unwrap(); - let count = array.count(CountMode::Valid)?.get(0).unwrap(); - let mean = calculate_mean(sum, count); - Ok(Stats { sum, count, mean }) + let sum = array.sum()?.get(0); + let count = array.count(CountMode::Valid)?.get(0); + let stats = sum + .zip(count) + .map_or_else(Default::default, |(sum, count)| Stats { + sum, + count, + mean: calculate_mean(sum, count), + }); + Ok(stats) } pub fn grouped_stats<'a>( @@ -47,16 +54,15 @@ impl<'a, I: Iterator> Iterator for GroupedStats< fn next(&mut self) -> Option { let (index, group) = self.groups.next()?; - let sum = self - .grouped_sum - .get(index) - .expect("All values in `self.grouped_sum` must be valid"); - let count = self - .grouped_count - .get(index) - .expect("All values in `self.grouped_count` must be valid"); - let mean = calculate_mean(sum, count); - let stats = Stats { sum, count, mean }; + let sum = self.grouped_sum.get(index); + let count = self.grouped_count.get(index); + let stats = sum + .zip(count) + .map_or_else(Default::default, |(sum, count)| Stats { + sum, + count, + mean: calculate_mean(sum, count), + }); Some((stats, group)) } } From a55ed2bd65b520eaa39b7b6dbce0653621f5e914 Mon Sep 17 00:00:00 2001 From: Raunak Bhagat Date: Mon, 7 Oct 2024 13:18:51 -0700 Subject: [PATCH 07/28] Add `#[pyfunctions]` functions to code - forgot to add these bindings --- src/daft-dsl/src/expr/mod.rs | 4 ++++ src/daft-dsl/src/python.rs | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/src/daft-dsl/src/expr/mod.rs b/src/daft-dsl/src/expr/mod.rs index c47456ae1d..300a4f8b64 100644 --- a/src/daft-dsl/src/expr/mod.rs +++ b/src/daft-dsl/src/expr/mod.rs @@ -505,6 +505,10 @@ impl Expr { Self::Agg(AggExpr::Mean(self)).into() } + pub fn stddev(self: ExprRef) -> ExprRef { + Self::Agg(AggExpr::Stddev(self)).into() + } + pub fn min(self: ExprRef) -> ExprRef { Self::Agg(AggExpr::Min(self)).into() } diff --git a/src/daft-dsl/src/python.rs b/src/daft-dsl/src/python.rs index edd3f5bcb4..17a92f7a7f 100644 --- a/src/daft-dsl/src/python.rs +++ b/src/daft-dsl/src/python.rs @@ -327,6 +327,10 @@ impl PyExpr { Ok(self.expr.clone().mean().into()) } + pub fn stddev(&self) -> PyResult { + Ok(self.expr.clone().stddev().into()) + } + pub fn min(&self) -> PyResult { Ok(self.expr.clone().min().into()) } From bd509d405987b55a34aa49cd914e0a33be99270d Mon Sep 17 00:00:00 2001 From: Raunak Bhagat Date: Mon, 7 Oct 2024 13:48:10 -0700 Subject: [PATCH 08/28] Add basic test for stddev --- tests/dataframe/test_stddev.py | 40 ++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) create mode 100644 tests/dataframe/test_stddev.py diff --git a/tests/dataframe/test_stddev.py b/tests/dataframe/test_stddev.py new file mode 100644 index 0000000000..0f4106b292 --- /dev/null +++ b/tests/dataframe/test_stddev.py @@ -0,0 +1,40 @@ +import functools +import math + +import pytest + +import daft + + +def stddev(nums) -> float: + if not nums: + return 0.0 + sum_: float = sum(nums) + count = len(nums) + mean = sum_ / count + squared_sums = functools.reduce(lambda acc, num: acc + (num - mean) ** 2, nums, 0) + stddev = math.sqrt(squared_sums / count) + return stddev + + +TESTS = [ + [nums := [0], stddev(nums)], + [nums := [0, 1, 2], stddev(nums)], + [nums := [0, 0, 0], stddev(nums)], +] + + +@pytest.mark.parametrize("data_and_expected", TESTS) +def test_stddev(data_and_expected): + data, expected = data_and_expected + df = daft.from_pydict({"a": data}) + result = df.agg(daft.col("a").stddev()).collect() + rows = result.iter_rows() + stddev = next(rows) + try: + next(rows) + assert False + except StopIteration: + pass + + assert stddev["a"] == expected From 1b4d0390c7d8b7083ff056c2b4a19e2450e0247e Mon Sep 17 00:00:00 2001 From: Raunak Bhagat Date: Mon, 7 Oct 2024 14:42:35 -0700 Subject: [PATCH 09/28] Add partition based testing - this tests the non-singular-partition based implementation --- tests/dataframe/test_stddev.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tests/dataframe/test_stddev.py b/tests/dataframe/test_stddev.py index 0f4106b292..044d5facc5 100644 --- a/tests/dataframe/test_stddev.py +++ b/tests/dataframe/test_stddev.py @@ -38,3 +38,19 @@ def test_stddev(data_and_expected): pass assert stddev["a"] == expected + + +@pytest.mark.parametrize("data_and_expected", TESTS) +def test_stddev_with_multiple_partitions(data_and_expected): + data, expected = data_and_expected + df = daft.from_pydict({"a": data}).into_partitions(2) + result = df.agg(daft.col("a").stddev()).collect() + rows = result.iter_rows() + stddev = next(rows) + try: + next(rows) + assert False + except StopIteration: + pass + + assert stddev["a"] == expected From 633f48641764b0019d603a2663902e1e56aac103 Mon Sep 17 00:00:00 2001 From: Raunak Bhagat Date: Mon, 7 Oct 2024 16:20:14 -0700 Subject: [PATCH 10/28] Add first stage pass to stddev distributed implementation --- .../src/physical_planner/translate.rs | 23 ++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/src/daft-plan/src/physical_planner/translate.rs b/src/daft-plan/src/physical_planner/translate.rs index f3ee4eaeed..e9790df2a6 100644 --- a/src/daft-plan/src/physical_planner/translate.rs +++ b/src/daft-plan/src/physical_planner/translate.rs @@ -774,6 +774,17 @@ pub fn populate_aggregation_stages( // Project the aggregation results to their final output names let mut final_exprs: Vec = group_by.iter().map(|e| col(e.name())).collect(); + let get_id = |expr: &AggExpr| expr.semantic_id(schema).id; + + fn add_to_stage(stage: &mut HashMap, AggExpr>, id: Arc, agg_expr: AggExpr) { + let prev_agg_expr = stage.insert(id.clone(), agg_expr); + assert!( + prev_agg_expr.is_none(), + "{:?} already exists in this stage but it should not", + id + ); + } + for agg_expr in aggregations { let output_name = agg_expr.name(); match agg_expr { @@ -834,7 +845,17 @@ pub fn populate_aggregation_stages( .alias(output_name), ); } - AggExpr::Stddev(_expr) => todo!("stddev"), + AggExpr::Stddev(sub_expr) => { + // first stage + let sum_expr = AggExpr::Sum(sub_expr.clone()); + let count_expr = AggExpr::Count(sub_expr.clone(), CountMode::Valid); + add_to_stage(&mut first_stage_aggs, get_id(&sum_expr), sum_expr); + add_to_stage(&mut first_stage_aggs, get_id(&count_expr), count_expr); + + // second stage + + todo!() + } AggExpr::Min(e) => { let min_id = agg_expr.semantic_id(schema).id; let min_of_min_id = AggExpr::Min(col(min_id.clone())).semantic_id(schema).id; From 9b9ac18509069b8634249533e9661395613cfd48 Mon Sep 17 00:00:00 2001 From: Raunak Bhagat Date: Mon, 7 Oct 2024 16:33:30 -0700 Subject: [PATCH 11/28] Add `StddevMerge` variant to finish the second stage aggregations --- src/daft-dsl/src/expr/mod.rs | 16 +++++++++++++--- src/daft-dsl/src/resolve_expr/mod.rs | 3 +++ src/daft-plan/src/logical_ops/project.rs | 8 ++++++-- src/daft-plan/src/physical_planner/translate.rs | 3 ++- src/daft-sql/src/modules/aggs.rs | 3 ++- src/daft-table/src/lib.rs | 1 + 6 files changed, 27 insertions(+), 7 deletions(-) diff --git a/src/daft-dsl/src/expr/mod.rs b/src/daft-dsl/src/expr/mod.rs index 300a4f8b64..034a557278 100644 --- a/src/daft-dsl/src/expr/mod.rs +++ b/src/daft-dsl/src/expr/mod.rs @@ -127,6 +127,9 @@ pub enum AggExpr { #[display("stddev({_0})")] Stddev(ExprRef), + #[display("stddev_merge({_0})")] + StddevMerge(ExprRef), + #[display("min({_0})")] Min(ExprRef), @@ -174,6 +177,7 @@ impl AggExpr { | Self::MergeSketch(expr, _) | Self::Mean(expr) | Self::Stddev(expr) + | Self::StddevMerge(expr) | Self::Min(expr) | Self::Max(expr) | Self::AnyValue(expr, _) @@ -220,13 +224,17 @@ impl AggExpr { "{child_id}.local_merge_sketch(sketch_type={sketch_type:?})" )) } + Self::Mean(expr) => { + let child_id = expr.semantic_id(schema); + FieldID::new(format!("{child_id}.local_mean()")) + } Self::Stddev(expr) => { let child_id = expr.semantic_id(schema); FieldID::new(format!("{child_id}.local_stddev()")) } - Self::Mean(expr) => { + Self::StddevMerge(expr) => { let child_id = expr.semantic_id(schema); - FieldID::new(format!("{child_id}.local_mean()")) + FieldID::new(format!("{child_id}.local_stddev_merge()")) } Self::Min(expr) => { let child_id = expr.semantic_id(schema); @@ -264,6 +272,7 @@ impl AggExpr { | Self::MergeSketch(expr, _) | Self::Mean(expr) | Self::Stddev(expr) + | Self::StddevMerge(expr) | Self::Min(expr) | Self::Max(expr) | Self::AnyValue(expr, _) @@ -285,6 +294,7 @@ impl AggExpr { Self::Sum(_) => Self::Sum(first_child()), Self::Mean(_) => Self::Mean(first_child()), Self::Stddev(_) => Self::Stddev(first_child()), + Self::StddevMerge(_) => Self::StddevMerge(first_child()), Self::Min(_) => Self::Min(first_child()), Self::Max(_) => Self::Max(first_child()), Self::AnyValue(_, ignore_nulls) => Self::AnyValue(first_child(), *ignore_nulls), @@ -382,7 +392,7 @@ impl AggExpr { }; Ok(Field::new(field.name, dtype)) } - Self::Mean(expr) | Self::Stddev(expr) => { + Self::Mean(expr) | Self::Stddev(expr) | Self::StddevMerge(expr) => { let field = expr.to_field(schema)?; Ok(Field::new( field.name.as_str(), diff --git a/src/daft-dsl/src/resolve_expr/mod.rs b/src/daft-dsl/src/resolve_expr/mod.rs index e2537ff03f..8a46faf694 100644 --- a/src/daft-dsl/src/resolve_expr/mod.rs +++ b/src/daft-dsl/src/resolve_expr/mod.rs @@ -239,6 +239,9 @@ fn extract_agg_expr(expr: &Expr) -> DaftResult { } AggExpr::Mean(e) => AggExpr::Mean(Expr::Alias(e, name.clone()).into()), AggExpr::Stddev(e) => AggExpr::Stddev(Expr::Alias(e, name.clone()).into()), + AggExpr::StddevMerge(e) => { + AggExpr::StddevMerge(Expr::Alias(e, name.clone()).into()) + } AggExpr::Min(e) => AggExpr::Min(Expr::Alias(e, name.clone()).into()), AggExpr::Max(e) => AggExpr::Max(Expr::Alias(e, name.clone()).into()), AggExpr::AnyValue(e, ignore_nulls) => { diff --git a/src/daft-plan/src/logical_ops/project.rs b/src/daft-plan/src/logical_ops/project.rs index e3e956cdaf..5a60bf0aa2 100644 --- a/src/daft-plan/src/logical_ops/project.rs +++ b/src/daft-plan/src/logical_ops/project.rs @@ -411,13 +411,17 @@ fn replace_column_with_semantic_id_aggexpr( |_| e, ) } + AggExpr::Mean(ref child) => { + replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema) + .map_yes_no(AggExpr::Mean, |_| e) + } AggExpr::Stddev(ref child) => { replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema) .map_yes_no(AggExpr::Stddev, |_| e) } - AggExpr::Mean(ref child) => { + AggExpr::StddevMerge(ref child) => { replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema) - .map_yes_no(AggExpr::Mean, |_| e) + .map_yes_no(AggExpr::StddevMerge, |_| e) } AggExpr::Min(ref child) => { replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema) diff --git a/src/daft-plan/src/physical_planner/translate.rs b/src/daft-plan/src/physical_planner/translate.rs index e9790df2a6..69980f125c 100644 --- a/src/daft-plan/src/physical_planner/translate.rs +++ b/src/daft-plan/src/physical_planner/translate.rs @@ -854,8 +854,9 @@ pub fn populate_aggregation_stages( // second stage - todo!() + todo!("stddev") } + AggExpr::StddevMerge(..) => todo!("stddev_merge"), AggExpr::Min(e) => { let min_id = agg_expr.semantic_id(schema).id; let min_of_min_id = AggExpr::Min(col(min_id.clone())).semantic_id(schema).id; diff --git a/src/daft-sql/src/modules/aggs.rs b/src/daft-sql/src/modules/aggs.rs index f3bbe81a30..ba03818583 100644 --- a/src/daft-sql/src/modules/aggs.rs +++ b/src/daft-sql/src/modules/aggs.rs @@ -109,7 +109,8 @@ pub(crate) fn to_expr(expr: &AggExpr, args: &[ExprRef]) -> SQLPlannerResult todo!(), + AggExpr::Stddev(..) => todo!("stddev"), + AggExpr::StddevMerge(..) => todo!("stddev_merge"), AggExpr::Min(_) => { ensure!(args.len() == 1, "min takes exactly one argument"); Ok(args[0].clone().min()) diff --git a/src/daft-table/src/lib.rs b/src/daft-table/src/lib.rs index eff28c6a26..f140f47927 100644 --- a/src/daft-table/src/lib.rs +++ b/src/daft-table/src/lib.rs @@ -481,6 +481,7 @@ impl Table { } AggExpr::Mean(expr) => self.eval_expression(expr)?.mean(groups), AggExpr::Stddev(expr) => self.eval_expression(expr)?.stddev(groups), + AggExpr::StddevMerge(..) => todo!("stddev merge"), AggExpr::Min(expr) => self.eval_expression(expr)?.min(groups), AggExpr::Max(expr) => self.eval_expression(expr)?.max(groups), &AggExpr::AnyValue(ref expr, ignore_nulls) => { From 9669a080d7baab80ee2987ce90fe88e5575e1c7a Mon Sep 17 00:00:00 2001 From: Raunak Bhagat Date: Mon, 7 Oct 2024 16:37:49 -0700 Subject: [PATCH 12/28] Implement `stddev_merge` todo - located in translate.rs - impl'd it with an `unimplemented!()` because a user-facing stddev_merge function is not exposed --- src/daft-plan/src/physical_planner/translate.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/daft-plan/src/physical_planner/translate.rs b/src/daft-plan/src/physical_planner/translate.rs index 69980f125c..00d9c2efaa 100644 --- a/src/daft-plan/src/physical_planner/translate.rs +++ b/src/daft-plan/src/physical_planner/translate.rs @@ -856,7 +856,9 @@ pub fn populate_aggregation_stages( todo!("stddev") } - AggExpr::StddevMerge(..) => todo!("stddev_merge"), + AggExpr::StddevMerge(..) => { + unimplemented!("User-facing stddev_merge aggregation is not implemented") + } AggExpr::Min(e) => { let min_id = agg_expr.semantic_id(schema).id; let min_of_min_id = AggExpr::Min(col(min_id.clone())).semantic_id(schema).id; From 64bff42667c0e04525ec2125f74defa9cdf087ad Mon Sep 17 00:00:00 2001 From: Raunak Bhagat Date: Mon, 7 Oct 2024 20:31:17 -0700 Subject: [PATCH 13/28] Finish distribtued stddev --- src/daft-core/src/array/ops/mod.rs | 7 +++ src/daft-core/src/array/ops/square_sum.rs | 37 ++++++++++++++ src/daft-core/src/array/ops/stddev.rs | 7 +-- src/daft-core/src/series/ops/agg.rs | 38 +++++++++----- src/daft-dsl/src/expr/mod.rs | 24 ++++----- src/daft-dsl/src/resolve_expr/mod.rs | 4 +- src/daft-plan/src/logical_ops/project.rs | 8 +-- .../src/physical_planner/translate.rs | 51 ++++++++++++++++--- src/daft-sql/src/modules/aggs.rs | 7 ++- src/daft-table/src/lib.rs | 2 +- 10 files changed, 137 insertions(+), 48 deletions(-) create mode 100644 src/daft-core/src/array/ops/square_sum.rs diff --git a/src/daft-core/src/array/ops/mod.rs b/src/daft-core/src/array/ops/mod.rs index 3bcf0f0cb9..41081fc118 100644 --- a/src/daft-core/src/array/ops/mod.rs +++ b/src/daft-core/src/array/ops/mod.rs @@ -49,6 +49,7 @@ mod sketch_percentile; mod sort; pub(crate) mod sparse_tensor; mod sqrt; +mod square_sum; mod stddev; mod struct_; mod sum; @@ -172,6 +173,12 @@ pub trait DaftSumAggable { fn grouped_sum(&self, groups: &GroupIndices) -> Self::Output; } +pub trait DaftSquareSumAggable { + type Output; + fn square_sum(&self) -> Self::Output; + fn grouped_square_sum(&self, groups: &GroupIndices) -> Self::Output; +} + pub trait DaftApproxSketchAggable { type Output; fn approx_sketch(&self) -> Self::Output; diff --git a/src/daft-core/src/array/ops/square_sum.rs b/src/daft-core/src/array/ops/square_sum.rs new file mode 100644 index 0000000000..7f4b1baf7e --- /dev/null +++ b/src/daft-core/src/array/ops/square_sum.rs @@ -0,0 +1,37 @@ +use arrow2::array::PrimitiveArray; +use common_error::DaftResult; + +use crate::array::{ + ops::{DaftSquareSumAggable, GroupIndices}, + prelude::Float64Array, +}; + +impl DaftSquareSumAggable for Float64Array { + type Output = DaftResult; + + fn square_sum(&self) -> Self::Output { + let sum_square = self + .into_iter() + .flatten() + .copied() + .fold(0., |acc, value| acc + value.powi(2)); + let data = PrimitiveArray::from([Some(sum_square)]).boxed(); + let field = self.field.clone(); + Self::new(field, data) + } + + fn grouped_square_sum(&self, groups: &GroupIndices) -> Self::Output { + let grouped_square_sum_iter = groups + .iter() + .map(|group| { + group.iter().copied().fold(0., |acc, index| { + self.get(index as _) + .map_or(acc, |value| acc + value.powi(2)) + }) + }) + .map(Some); + let data = PrimitiveArray::from_trusted_len_iter(grouped_square_sum_iter).boxed(); + let field = self.field.clone(); + Self::new(field, data) + } +} diff --git a/src/daft-core/src/array/ops/stddev.rs b/src/daft-core/src/array/ops/stddev.rs index 42692b62ba..c412922937 100644 --- a/src/daft-core/src/array/ops/stddev.rs +++ b/src/daft-core/src/array/ops/stddev.rs @@ -1,8 +1,5 @@ -use std::sync::Arc; - use arrow2::array::PrimitiveArray; use common_error::DaftResult; -use daft_schema::{dtype::DataType, field::Field}; use crate::{ array::{ @@ -20,7 +17,7 @@ impl DaftStddevAggable for DataArray { let stats = stats::calculate_stats(self)?; let values = self.into_iter().flatten().copied(); let stddev = stats::calculate_stddev(stats, values); - let field = Arc::new(Field::new(self.field.name.clone(), DataType::Float64)); + let field = self.field.clone(); let data = PrimitiveArray::::from([stddev]).boxed(); Self::new(field, data) } @@ -30,7 +27,7 @@ impl DaftStddevAggable for DataArray { let values = group.iter().filter_map(|&index| self.get(index as _)); stats::calculate_stddev(stats, values) }); - let field = Arc::new(Field::new(self.field.name.clone(), DataType::Float64)); + let field = self.field.clone(); let data = PrimitiveArray::::from_iter(grouped_stddevs_iter).boxed(); Self::new(field, data) } diff --git a/src/daft-core/src/series/ops/agg.rs b/src/daft-core/src/series/ops/agg.rs index 79cfaa484a..20176990aa 100644 --- a/src/daft-core/src/series/ops/agg.rs +++ b/src/daft-core/src/series/ops/agg.rs @@ -4,7 +4,10 @@ use logical::Decimal128Array; use crate::{ array::{ - ops::{DaftHllMergeAggable, DaftMeanAggable, DaftStddevAggable, GroupIndices}, + ops::{ + DaftApproxSketchAggable, DaftHllMergeAggable, DaftMeanAggable, DaftSquareSumAggable, + DaftStddevAggable, DaftSumAggable, GroupIndices, + }, ListArray, }, count_mode::CountMode, @@ -26,12 +29,10 @@ impl Series { } pub fn sum(&self, groups: Option<&GroupIndices>) -> DaftResult { - use crate::{array::ops::DaftSumAggable, datatypes::DataType::*}; - match self.data_type() { // intX -> int64 (in line with numpy) - Int8 | Int16 | Int32 | Int64 => { - let casted = self.cast(&Int64)?; + DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => { + let casted = self.cast(&DataType::Int64)?; match groups { Some(groups) => { Ok(DaftSumAggable::grouped_sum(&casted.i64()?, groups)?.into_series()) @@ -40,8 +41,8 @@ impl Series { } } // uintX -> uint64 (in line with numpy) - UInt8 | UInt16 | UInt32 | UInt64 => { - let casted = self.cast(&UInt64)?; + DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => { + let casted = self.cast(&DataType::UInt64)?; match groups { Some(groups) => { Ok(DaftSumAggable::grouped_sum(&casted.u64()?, groups)?.into_series()) @@ -50,7 +51,7 @@ impl Series { } } // floatX -> floatX (in line with numpy) - Float32 => match groups { + DataType::Float32 => match groups { Some(groups) => Ok(DaftSumAggable::grouped_sum( &self.downcast::()?, groups, @@ -58,7 +59,7 @@ impl Series { .into_series()), None => Ok(DaftSumAggable::sum(&self.downcast::()?)?.into_series()), }, - Float64 => match groups { + DataType::Float64 => match groups { Some(groups) => Ok(DaftSumAggable::grouped_sum( &self.downcast::()?, groups, @@ -66,7 +67,7 @@ impl Series { .into_series()), None => Ok(DaftSumAggable::sum(&self.downcast::()?)?.into_series()), }, - Decimal128(_, _) => match groups { + DataType::Decimal128(_, _) => match groups { Some(groups) => Ok(Decimal128Array::new( Field { dtype: try_sum_supertype(self.data_type())?, @@ -94,13 +95,24 @@ impl Series { } } - pub fn approx_sketch(&self, groups: Option<&GroupIndices>) -> DaftResult { - use crate::{array::ops::DaftApproxSketchAggable, datatypes::DataType::*}; + pub fn square_sum(&self, groups: Option<&GroupIndices>) -> DaftResult { + self.data_type().assert_is_numeric()?; + let casted = self.cast(&DataType::Float64)?; + let casted = casted.f64()?; + let series = groups + .map_or_else( + || casted.square_sum(), + |groups| casted.grouped_square_sum(groups), + )? + .into_series(); + Ok(series) + } + pub fn approx_sketch(&self, groups: Option<&GroupIndices>) -> DaftResult { // Upcast all numeric types to float64 and compute approx_sketch. match self.data_type() { dt if dt.is_numeric() => { - let casted = self.cast(&Float64)?; + let casted = self.cast(&DataType::Float64)?; match groups { Some(groups) => Ok(DaftApproxSketchAggable::grouped_approx_sketch( &casted.f64()?, diff --git a/src/daft-dsl/src/expr/mod.rs b/src/daft-dsl/src/expr/mod.rs index 034a557278..1529bdcb7f 100644 --- a/src/daft-dsl/src/expr/mod.rs +++ b/src/daft-dsl/src/expr/mod.rs @@ -109,6 +109,9 @@ pub enum AggExpr { #[display("sum({_0})")] Sum(ExprRef), + #[display("square_sum({_0})")] + SquareSum(ExprRef), + #[display("approx_percentile({}, percentiles={:?}, force_list_output={})", _0.child, _0.percentiles, _0.force_list_output)] ApproxPercentile(ApproxPercentileParams), @@ -127,9 +130,6 @@ pub enum AggExpr { #[display("stddev({_0})")] Stddev(ExprRef), - #[display("stddev_merge({_0})")] - StddevMerge(ExprRef), - #[display("min({_0})")] Min(ExprRef), @@ -171,13 +171,13 @@ impl AggExpr { match self { Self::Count(expr, ..) | Self::Sum(expr) + | Self::SquareSum(expr) | Self::ApproxPercentile(ApproxPercentileParams { child: expr, .. }) | Self::ApproxCountDistinct(expr) | Self::ApproxSketch(expr, _) | Self::MergeSketch(expr, _) | Self::Mean(expr) | Self::Stddev(expr) - | Self::StddevMerge(expr) | Self::Min(expr) | Self::Max(expr) | Self::AnyValue(expr, _) @@ -197,6 +197,10 @@ impl AggExpr { let child_id = expr.semantic_id(schema); FieldID::new(format!("{child_id}.local_sum()")) } + Self::SquareSum(expr) => { + let child_id = expr.semantic_id(schema); + FieldID::new(format!("{child_id}.local_square_sum()")) + } Self::ApproxPercentile(ApproxPercentileParams { child: expr, percentiles, @@ -232,10 +236,6 @@ impl AggExpr { let child_id = expr.semantic_id(schema); FieldID::new(format!("{child_id}.local_stddev()")) } - Self::StddevMerge(expr) => { - let child_id = expr.semantic_id(schema); - FieldID::new(format!("{child_id}.local_stddev_merge()")) - } Self::Min(expr) => { let child_id = expr.semantic_id(schema); FieldID::new(format!("{child_id}.local_min()")) @@ -266,13 +266,13 @@ impl AggExpr { match self { Self::Count(expr, ..) | Self::Sum(expr) + | Self::SquareSum(expr) | Self::ApproxPercentile(ApproxPercentileParams { child: expr, .. }) | Self::ApproxCountDistinct(expr) | Self::ApproxSketch(expr, _) | Self::MergeSketch(expr, _) | Self::Mean(expr) | Self::Stddev(expr) - | Self::StddevMerge(expr) | Self::Min(expr) | Self::Max(expr) | Self::AnyValue(expr, _) @@ -292,9 +292,9 @@ impl AggExpr { match self { Self::Count(_, count_mode) => Self::Count(first_child(), *count_mode), Self::Sum(_) => Self::Sum(first_child()), + Self::SquareSum(_) => Self::SquareSum(first_child()), Self::Mean(_) => Self::Mean(first_child()), Self::Stddev(_) => Self::Stddev(first_child()), - Self::StddevMerge(_) => Self::StddevMerge(first_child()), Self::Min(_) => Self::Min(first_child()), Self::Max(_) => Self::Max(first_child()), Self::AnyValue(_, ignore_nulls) => Self::AnyValue(first_child(), *ignore_nulls), @@ -325,7 +325,7 @@ impl AggExpr { let field = expr.to_field(schema)?; Ok(Field::new(field.name.as_str(), DataType::UInt64)) } - Self::Sum(expr) => { + Self::Sum(expr) | Self::SquareSum(expr) => { let field = expr.to_field(schema)?; Ok(Field::new( field.name.as_str(), @@ -392,7 +392,7 @@ impl AggExpr { }; Ok(Field::new(field.name, dtype)) } - Self::Mean(expr) | Self::Stddev(expr) | Self::StddevMerge(expr) => { + Self::Mean(expr) | Self::Stddev(expr) => { let field = expr.to_field(schema)?; Ok(Field::new( field.name.as_str(), diff --git a/src/daft-dsl/src/resolve_expr/mod.rs b/src/daft-dsl/src/resolve_expr/mod.rs index 8a46faf694..75407f4af3 100644 --- a/src/daft-dsl/src/resolve_expr/mod.rs +++ b/src/daft-dsl/src/resolve_expr/mod.rs @@ -219,6 +219,7 @@ fn extract_agg_expr(expr: &Expr) -> DaftResult { AggExpr::Count(Expr::Alias(e, name.clone()).into(), count_mode) } AggExpr::Sum(e) => AggExpr::Sum(Expr::Alias(e, name.clone()).into()), + AggExpr::SquareSum(e) => AggExpr::SquareSum(e.alias(name.clone())), AggExpr::ApproxPercentile(ApproxPercentileParams { child: e, percentiles, @@ -239,9 +240,6 @@ fn extract_agg_expr(expr: &Expr) -> DaftResult { } AggExpr::Mean(e) => AggExpr::Mean(Expr::Alias(e, name.clone()).into()), AggExpr::Stddev(e) => AggExpr::Stddev(Expr::Alias(e, name.clone()).into()), - AggExpr::StddevMerge(e) => { - AggExpr::StddevMerge(Expr::Alias(e, name.clone()).into()) - } AggExpr::Min(e) => AggExpr::Min(Expr::Alias(e, name.clone()).into()), AggExpr::Max(e) => AggExpr::Max(Expr::Alias(e, name.clone()).into()), AggExpr::AnyValue(e, ignore_nulls) => { diff --git a/src/daft-plan/src/logical_ops/project.rs b/src/daft-plan/src/logical_ops/project.rs index 5a60bf0aa2..44c27270bf 100644 --- a/src/daft-plan/src/logical_ops/project.rs +++ b/src/daft-plan/src/logical_ops/project.rs @@ -380,6 +380,10 @@ fn replace_column_with_semantic_id_aggexpr( replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema) .map_yes_no(AggExpr::Sum, |_| e) } + AggExpr::SquareSum(ref child) => { + replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema) + .map_yes_no(AggExpr::SquareSum, |_| e) + } AggExpr::ApproxPercentile(ApproxPercentileParams { ref child, ref percentiles, @@ -419,10 +423,6 @@ fn replace_column_with_semantic_id_aggexpr( replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema) .map_yes_no(AggExpr::Stddev, |_| e) } - AggExpr::StddevMerge(ref child) => { - replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema) - .map_yes_no(AggExpr::StddevMerge, |_| e) - } AggExpr::Min(ref child) => { replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema) .map_yes_no(AggExpr::Min, |_| e) diff --git a/src/daft-plan/src/physical_planner/translate.rs b/src/daft-plan/src/physical_planner/translate.rs index 00d9c2efaa..5252b690bc 100644 --- a/src/daft-plan/src/physical_planner/translate.rs +++ b/src/daft-plan/src/physical_planner/translate.rs @@ -814,6 +814,9 @@ pub fn populate_aggregation_stages( )); final_exprs.push(col(sum_of_sum_id.clone()).alias(output_name)); } + AggExpr::SquareSum(..) => { + unimplemented!("User-facing square_sum aggregation is not implemented") + } AggExpr::Mean(e) => { let sum_id = AggExpr::Sum(e.clone()).semantic_id(schema).id; let count_id = AggExpr::Count(e.clone(), CountMode::Valid) @@ -846,18 +849,50 @@ pub fn populate_aggregation_stages( ); } AggExpr::Stddev(sub_expr) => { - // first stage + // first stage aggregation let sum_expr = AggExpr::Sum(sub_expr.clone()); + let sq_sum_expr = AggExpr::SquareSum(sub_expr.clone()); let count_expr = AggExpr::Count(sub_expr.clone(), CountMode::Valid); - add_to_stage(&mut first_stage_aggs, get_id(&sum_expr), sum_expr); - add_to_stage(&mut first_stage_aggs, get_id(&count_expr), count_expr); + let sum_id = get_id(&sum_expr); + let sq_sum_id = get_id(&sq_sum_expr); + let count_id = get_id(&count_expr); + add_to_stage(&mut first_stage_aggs, sum_id.clone(), sum_expr); + add_to_stage(&mut first_stage_aggs, sq_sum_id.clone(), sq_sum_expr); + add_to_stage(&mut first_stage_aggs, count_id.clone(), count_expr); + + // second stage aggregation + let global_sum_expr = AggExpr::Sum(col(sum_id)); + let global_sq_sum_expr = AggExpr::Sum(col(sq_sum_id)); + let global_count_expr = AggExpr::Sum(col(count_id)); + let global_sum_id = get_id(&global_sum_expr); + let global_sq_sum_id = get_id(&global_sq_sum_expr); + let global_count_id = get_id(&global_count_expr); + add_to_stage( + &mut second_stage_aggs, + global_sum_id.clone(), + global_sum_expr, + ); + add_to_stage( + &mut second_stage_aggs, + global_sq_sum_id.clone(), + global_sq_sum_expr, + ); + add_to_stage( + &mut second_stage_aggs, + global_count_id.clone(), + global_count_expr, + ); - // second stage + // final projection + let g_sq_sum = col(global_sq_sum_id); + let g_sum = col(global_sum_id); + let g_count = col(global_count_id); + let left = g_sq_sum.div(g_count.clone()); + let right = g_sum.div(g_count); + let right = right.clone().mul(right); + let result = left.sub(right); - todo!("stddev") - } - AggExpr::StddevMerge(..) => { - unimplemented!("User-facing stddev_merge aggregation is not implemented") + final_exprs.push(result); } AggExpr::Min(e) => { let min_id = agg_expr.semantic_id(schema).id; diff --git a/src/daft-sql/src/modules/aggs.rs b/src/daft-sql/src/modules/aggs.rs index ba03818583..a783075ce1 100644 --- a/src/daft-sql/src/modules/aggs.rs +++ b/src/daft-sql/src/modules/aggs.rs @@ -101,6 +101,7 @@ pub(crate) fn to_expr(expr: &AggExpr, args: &[ExprRef]) -> SQLPlannerResult unsupported_sql_err!("square_sum"), AggExpr::ApproxCountDistinct(_) => unsupported_sql_err!("approx_percentile"), AggExpr::ApproxPercentile(_) => unsupported_sql_err!("approx_percentile"), AggExpr::ApproxSketch(_, _) => unsupported_sql_err!("approx_sketch"), @@ -109,8 +110,10 @@ pub(crate) fn to_expr(expr: &AggExpr, args: &[ExprRef]) -> SQLPlannerResult todo!("stddev"), - AggExpr::StddevMerge(..) => todo!("stddev_merge"), + AggExpr::Stddev(_) => { + ensure!(args.len() == 1, "stddev takes exactly one argument"); + Ok(args[0].clone().stddev()) + } AggExpr::Min(_) => { ensure!(args.len() == 1, "min takes exactly one argument"); Ok(args[0].clone().min()) diff --git a/src/daft-table/src/lib.rs b/src/daft-table/src/lib.rs index f140f47927..e6411d8ce5 100644 --- a/src/daft-table/src/lib.rs +++ b/src/daft-table/src/lib.rs @@ -436,6 +436,7 @@ impl Table { match agg_expr { &AggExpr::Count(ref expr, mode) => self.eval_expression(expr)?.count(groups, mode), AggExpr::Sum(expr) => self.eval_expression(expr)?.sum(groups), + AggExpr::SquareSum(expr) => self.eval_expression(expr)?.square_sum(groups), &AggExpr::ApproxPercentile(ApproxPercentileParams { child: ref expr, ref percentiles, @@ -481,7 +482,6 @@ impl Table { } AggExpr::Mean(expr) => self.eval_expression(expr)?.mean(groups), AggExpr::Stddev(expr) => self.eval_expression(expr)?.stddev(groups), - AggExpr::StddevMerge(..) => todo!("stddev merge"), AggExpr::Min(expr) => self.eval_expression(expr)?.min(groups), AggExpr::Max(expr) => self.eval_expression(expr)?.max(groups), &AggExpr::AnyValue(ref expr, ignore_nulls) => { From d825a216636600d255da1e5bf8f83bce855d9794 Mon Sep 17 00:00:00 2001 From: Raunak Bhagat Date: Tue, 8 Oct 2024 01:05:14 -0700 Subject: [PATCH 14/28] Edit data-type of `square_sum` field in `to_field` impl --- src/daft-core/src/array/ops/square_sum.rs | 4 ++-- src/daft-dsl/src/expr/mod.rs | 6 +++++- src/daft-plan/src/physical_planner/translate.rs | 2 +- src/daft-table/src/lib.rs | 8 ++++---- 4 files changed, 12 insertions(+), 8 deletions(-) diff --git a/src/daft-core/src/array/ops/square_sum.rs b/src/daft-core/src/array/ops/square_sum.rs index 7f4b1baf7e..56abc4550e 100644 --- a/src/daft-core/src/array/ops/square_sum.rs +++ b/src/daft-core/src/array/ops/square_sum.rs @@ -14,7 +14,7 @@ impl DaftSquareSumAggable for Float64Array { .into_iter() .flatten() .copied() - .fold(0., |acc, value| acc + value.powi(2)); + .fold(0., |acc, value| value.mul_add(value, acc)); let data = PrimitiveArray::from([Some(sum_square)]).boxed(); let field = self.field.clone(); Self::new(field, data) @@ -26,7 +26,7 @@ impl DaftSquareSumAggable for Float64Array { .map(|group| { group.iter().copied().fold(0., |acc, index| { self.get(index as _) - .map_or(acc, |value| acc + value.powi(2)) + .map_or(acc, |value| value.mul_add(value, acc)) }) }) .map(Some); diff --git a/src/daft-dsl/src/expr/mod.rs b/src/daft-dsl/src/expr/mod.rs index c7127c4654..1c7a73c984 100644 --- a/src/daft-dsl/src/expr/mod.rs +++ b/src/daft-dsl/src/expr/mod.rs @@ -325,13 +325,17 @@ impl AggExpr { let field = expr.to_field(schema)?; Ok(Field::new(field.name.as_str(), DataType::UInt64)) } - Self::Sum(expr) | Self::SquareSum(expr) => { + Self::Sum(expr) => { let field = expr.to_field(schema)?; Ok(Field::new( field.name.as_str(), try_sum_supertype(&field.dtype)?, )) } + Self::SquareSum(expr) => { + let field = expr.to_field(schema)?; + Ok(Field::new(field.name.as_str(), DataType::Float64)) + } Self::ApproxPercentile(ApproxPercentileParams { child: expr, percentiles, diff --git a/src/daft-plan/src/physical_planner/translate.rs b/src/daft-plan/src/physical_planner/translate.rs index aebf2772e4..787ba8583c 100644 --- a/src/daft-plan/src/physical_planner/translate.rs +++ b/src/daft-plan/src/physical_planner/translate.rs @@ -890,7 +890,7 @@ pub fn populate_aggregation_stages( let left = g_sq_sum.div(g_count.clone()); let right = g_sum.div(g_count); let right = right.clone().mul(right); - let result = left.sub(right); + let result = left.sub(right).alias(output_name); final_exprs.push(result); } diff --git a/src/daft-table/src/lib.rs b/src/daft-table/src/lib.rs index 2b14ae750a..10972755a8 100644 --- a/src/daft-table/src/lib.rs +++ b/src/daft-table/src/lib.rs @@ -586,10 +586,10 @@ impl Table { assert!( !(expected_field.dtype != series.field().dtype), "Data type mismatch in expression evaluation:\n\ - Expected type: {}\n\ - Computed type: {}\n\ - Expression: {}\n\ - This likely indicates an internal error in type inference or computation.", + Expected type: {}\n\ + Computed type: {}\n\ + Expression: {}\n\ + This likely indicates an internal error in type inference or computation.", expected_field.dtype, series.field().dtype, expr From 9b946269f4eb272f78ccd6ce6ed47aeb8e57d44f Mon Sep 17 00:00:00 2001 From: Raunak Bhagat Date: Tue, 8 Oct 2024 07:35:22 -0700 Subject: [PATCH 15/28] Fix errors in multi-partition aggregation planning --- .../src/physical_planner/translate.rs | 74 +++++++++++-------- 1 file changed, 45 insertions(+), 29 deletions(-) diff --git a/src/daft-plan/src/physical_planner/translate.rs b/src/daft-plan/src/physical_planner/translate.rs index 787ba8583c..d689d654be 100644 --- a/src/daft-plan/src/physical_planner/translate.rs +++ b/src/daft-plan/src/physical_planner/translate.rs @@ -11,6 +11,7 @@ use daft_core::prelude::*; use daft_dsl::{ col, is_partition_compatible, AggExpr, ApproxPercentileParams, ExprRef, SketchType, }; +use daft_functions::numeric::sqrt; use daft_scan::PhysicalScanInfo; use crate::{ @@ -774,15 +775,24 @@ pub fn populate_aggregation_stages( // Project the aggregation results to their final output names let mut final_exprs: Vec = group_by.iter().map(|e| col(e.name())).collect(); - let get_id = |expr: &AggExpr| expr.semantic_id(schema).id; - - fn add_to_stage(stage: &mut HashMap, AggExpr>, id: Arc, agg_expr: AggExpr) { + fn add_to_stage( + f: F, + expr: ExprRef, + schema: &Schema, + stage: &mut HashMap, AggExpr>, + ) -> Arc + where + F: Fn(ExprRef) -> AggExpr, + { + let id = f(expr.clone()).semantic_id(schema).id; + let agg_expr = f(expr.alias(id.clone())); let prev_agg_expr = stage.insert(id.clone(), agg_expr); assert!( prev_agg_expr.is_none(), "{:?} already exists in this stage but it should not", - id + &id ); + id } for agg_expr in aggregations { @@ -850,37 +860,43 @@ pub fn populate_aggregation_stages( } AggExpr::Stddev(sub_expr) => { // first stage aggregation - let sum_expr = AggExpr::Sum(sub_expr.clone()); - let sq_sum_expr = AggExpr::SquareSum(sub_expr.clone()); - let count_expr = AggExpr::Count(sub_expr.clone(), CountMode::Valid); - let sum_id = get_id(&sum_expr); - let sq_sum_id = get_id(&sq_sum_expr); - let count_id = get_id(&count_expr); - add_to_stage(&mut first_stage_aggs, sum_id.clone(), sum_expr); - add_to_stage(&mut first_stage_aggs, sq_sum_id.clone(), sq_sum_expr); - add_to_stage(&mut first_stage_aggs, count_id.clone(), count_expr); + let sum_id = add_to_stage( + AggExpr::Sum, + sub_expr.clone(), + schema, + &mut first_stage_aggs, + ); + let sq_sum_id = add_to_stage( + AggExpr::SquareSum, + sub_expr.clone(), + schema, + &mut first_stage_aggs, + ); + let count_id = add_to_stage( + |sub_expr| AggExpr::Count(sub_expr, CountMode::Valid), + sub_expr.clone(), + schema, + &mut first_stage_aggs, + ); // second stage aggregation - let global_sum_expr = AggExpr::Sum(col(sum_id)); - let global_sq_sum_expr = AggExpr::Sum(col(sq_sum_id)); - let global_count_expr = AggExpr::Sum(col(count_id)); - let global_sum_id = get_id(&global_sum_expr); - let global_sq_sum_id = get_id(&global_sq_sum_expr); - let global_count_id = get_id(&global_count_expr); - add_to_stage( + let global_sum_id = add_to_stage( + AggExpr::Sum, + col(sum_id.clone()), + schema, &mut second_stage_aggs, - global_sum_id.clone(), - global_sum_expr, ); - add_to_stage( + let global_sq_sum_id = add_to_stage( + AggExpr::Sum, + col(sq_sum_id.clone()), + schema, &mut second_stage_aggs, - global_sq_sum_id.clone(), - global_sq_sum_expr, ); - add_to_stage( + let global_count_id = add_to_stage( + AggExpr::Sum, + col(count_id.clone()), + schema, &mut second_stage_aggs, - global_count_id.clone(), - global_count_expr, ); // final projection @@ -890,7 +906,7 @@ pub fn populate_aggregation_stages( let left = g_sq_sum.div(g_count.clone()); let right = g_sum.div(g_count); let right = right.clone().mul(right); - let result = left.sub(right).alias(output_name); + let result = sqrt::sqrt(left.sub(right)).alias(output_name); final_exprs.push(result); } From 70577abd129c5ebfdc5537a937294839f65fc116 Mon Sep 17 00:00:00 2001 From: Raunak Bhagat Date: Tue, 8 Oct 2024 07:41:39 -0700 Subject: [PATCH 16/28] Add some tests for stddev (single- and multi- partitioned) --- tests/dataframe/test_stddev.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/tests/dataframe/test_stddev.py b/tests/dataframe/test_stddev.py index 044d5facc5..37a54c74e0 100644 --- a/tests/dataframe/test_stddev.py +++ b/tests/dataframe/test_stddev.py @@ -7,25 +7,36 @@ def stddev(nums) -> float: + def sum_reducer(acc, num): + return acc + num if num is not None else acc + + def stddev_reducer(acc, num): + return acc + (num - mean) ** 2 if num is not None else acc + if not nums: return 0.0 - sum_: float = sum(nums) + # sum_: float = sum(nums) + sum = functools.reduce(sum_reducer, nums, 0) count = len(nums) - mean = sum_ / count - squared_sums = functools.reduce(lambda acc, num: acc + (num - mean) ** 2, nums, 0) + mean = sum / count + + squared_sums = functools.reduce(stddev_reducer, nums, 0) stddev = math.sqrt(squared_sums / count) return stddev TESTS = [ [nums := [0], stddev(nums)], + [nums := [1], stddev(nums)], [nums := [0, 1, 2], stddev(nums)], [nums := [0, 0, 0], stddev(nums)], + [nums := [None, 0, None], stddev(nums)], + [nums := [None] * 10 + [0], stddev(nums)], ] @pytest.mark.parametrize("data_and_expected", TESTS) -def test_stddev(data_and_expected): +def test_stddev_with_single_partition(data_and_expected): data, expected = data_and_expected df = daft.from_pydict({"a": data}) result = df.agg(daft.col("a").stddev()).collect() From 265a7a7fcf84106e5f6abeb05bfa1bef96ae03f1 Mon Sep 17 00:00:00 2001 From: Raunak Bhagat Date: Tue, 8 Oct 2024 08:32:34 -0700 Subject: [PATCH 17/28] Finish tests for stddev feature --- tests/dataframe/test_stddev.py | 100 +++++++++++++++++++++++++++++---- 1 file changed, 88 insertions(+), 12 deletions(-) diff --git a/tests/dataframe/test_stddev.py b/tests/dataframe/test_stddev.py index 37a54c74e0..6679893d3a 100644 --- a/tests/dataframe/test_stddev.py +++ b/tests/dataframe/test_stddev.py @@ -1,26 +1,38 @@ import functools import math +import pandas as pd import pytest import daft -def stddev(nums) -> float: - def sum_reducer(acc, num): - return acc + num if num is not None else acc +def grouped_stddev(rows) -> tuple[list[any], list[any]]: + map = {} + for key, data in rows: + if key not in map: + map[key] = [] + map[key].append(data) + + keys = [] + stddevs = [] + for key, nums in map.items(): + keys.append(key) + stddevs.append(stddev(nums)) + + return keys, stddevs - def stddev_reducer(acc, num): - return acc + (num - mean) ** 2 if num is not None else acc + +def stddev(nums) -> float: + nums = [num for num in nums if num is not None] if not nums: return 0.0 - # sum_: float = sum(nums) - sum = functools.reduce(sum_reducer, nums, 0) + sum_: float = sum(nums) count = len(nums) - mean = sum / count + mean = sum_ / count - squared_sums = functools.reduce(stddev_reducer, nums, 0) + squared_sums = functools.reduce(lambda acc, num: acc + (num - mean) ** 2, nums, 0) stddev = math.sqrt(squared_sums / count) return stddev @@ -29,9 +41,9 @@ def stddev_reducer(acc, num): [nums := [0], stddev(nums)], [nums := [1], stddev(nums)], [nums := [0, 1, 2], stddev(nums)], - [nums := [0, 0, 0], stddev(nums)], - [nums := [None, 0, None], stddev(nums)], - [nums := [None] * 10 + [0], stddev(nums)], + [nums := [100, 100, 100], stddev(nums)], + [nums := [None, 100, None], stddev(nums)], + [nums := [None] * 10 + [100], stddev(nums)], ] @@ -65,3 +77,67 @@ def test_stddev_with_multiple_partitions(data_and_expected): pass assert stddev["a"] == expected + + +GROUPED_TESTS = [ + [rows := [("k1", 0), ("k2", 1), ("k1", 1)], *grouped_stddev(rows)], + [rows := [("k0", 100), ("k1", 100), ("k2", 100)], *grouped_stddev(rows)], + [rows := [("k0", 100), ("k0", 100), ("k0", 100)], *grouped_stddev(rows)], + [rows := [("k0", 0), ("k0", 1), ("k0", 2)], *grouped_stddev(rows)], + [rows := [("k0", None), ("k0", None), ("k0", 100)], *grouped_stddev(rows)], +] + + +def unzip_rows(rows: list) -> tuple[list, list]: + keys = [] + nums = [] + for key, data in rows: + keys.append(key) + nums.append(data) + return keys, nums + + +@pytest.mark.parametrize("data_and_expected", GROUPED_TESTS) +def test_grouped_stddev_with_single_partition(data_and_expected): + nums, expected_keys, expected_stddevs = data_and_expected + expected_df = daft.from_pydict({"keys": expected_keys, "data": expected_stddevs}) + keys, data = unzip_rows(nums) + df = daft.from_pydict({"keys": keys, "data": data}) + result_df = df.groupby("keys").agg(daft.col("data").stddev()).collect() + + result = result_df.to_pydict() + expected = expected_df.to_pydict() + + pd.testing.assert_series_equal( + pd.Series(result["keys"]).sort_values(), + pd.Series(expected["keys"]).sort_values(), + check_index=False, + ) + pd.testing.assert_series_equal( + pd.Series(result["data"]).sort_values(), + pd.Series(expected["data"]).sort_values(), + check_index=False, + ) + + +@pytest.mark.parametrize("data_and_expected", GROUPED_TESTS) +def test_grouped_stddev_with_multiple_partitions(data_and_expected): + nums, expected_keys, expected_stddevs = data_and_expected + expected_df = daft.from_pydict({"keys": expected_keys, "data": expected_stddevs}) + keys, data = unzip_rows(nums) + df = daft.from_pydict({"keys": keys, "data": data}).into_partitions(2) + result_df = df.groupby("keys").agg(daft.col("data").stddev()).collect() + + result = result_df.to_pydict() + expected = expected_df.to_pydict() + + pd.testing.assert_series_equal( + pd.Series(result["keys"]).sort_values(), + pd.Series(expected["keys"]).sort_values(), + check_index=False, + ) + pd.testing.assert_series_equal( + pd.Series(result["data"]).sort_values(), + pd.Series(expected["data"]).sort_values(), + check_index=False, + ) From a76fade0a433ee27ea64d102bd2189a7c6381c02 Mon Sep 17 00:00:00 2001 From: Raunak Bhagat Date: Tue, 8 Oct 2024 08:40:29 -0700 Subject: [PATCH 18/28] Explicitly import typing module; fix lints --- tests/dataframe/test_stddev.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/dataframe/test_stddev.py b/tests/dataframe/test_stddev.py index 6679893d3a..464d20bd41 100644 --- a/tests/dataframe/test_stddev.py +++ b/tests/dataframe/test_stddev.py @@ -1,5 +1,6 @@ import functools import math +from typing import Any, List, Tuple import pandas as pd import pytest @@ -7,7 +8,7 @@ import daft -def grouped_stddev(rows) -> tuple[list[any], list[any]]: +def grouped_stddev(rows) -> Tuple[List[Any], List[Any]]: map = {} for key, data in rows: if key not in map: @@ -88,7 +89,7 @@ def test_stddev_with_multiple_partitions(data_and_expected): ] -def unzip_rows(rows: list) -> tuple[list, list]: +def unzip_rows(rows: list) -> Tuple[List, List]: keys = [] nums = [] for key, data in rows: From 7a5a36ada4a88121683e8b68c2d6b7118e3d89c9 Mon Sep 17 00:00:00 2001 From: Raunak Bhagat Date: Tue, 8 Oct 2024 09:29:34 -0700 Subject: [PATCH 19/28] Remove `SquareSum` since it can just be implemented as `AggExpr::Sum(Expr::Agg(AggExpr::BinaryOp { .. }))` --- src/daft-core/src/array/ops/mod.rs | 7 ---- src/daft-core/src/array/ops/square_sum.rs | 37 ------------------- src/daft-core/src/series/ops/agg.rs | 17 +-------- src/daft-dsl/src/expr/mod.rs | 15 +------- src/daft-dsl/src/resolve_expr/mod.rs | 1 - src/daft-plan/src/logical_ops/project.rs | 4 -- .../src/physical_planner/translate.rs | 5 +-- src/daft-sql/src/modules/aggs.rs | 1 - src/daft-table/src/lib.rs | 1 - 9 files changed, 4 insertions(+), 84 deletions(-) delete mode 100644 src/daft-core/src/array/ops/square_sum.rs diff --git a/src/daft-core/src/array/ops/mod.rs b/src/daft-core/src/array/ops/mod.rs index 41081fc118..3bcf0f0cb9 100644 --- a/src/daft-core/src/array/ops/mod.rs +++ b/src/daft-core/src/array/ops/mod.rs @@ -49,7 +49,6 @@ mod sketch_percentile; mod sort; pub(crate) mod sparse_tensor; mod sqrt; -mod square_sum; mod stddev; mod struct_; mod sum; @@ -173,12 +172,6 @@ pub trait DaftSumAggable { fn grouped_sum(&self, groups: &GroupIndices) -> Self::Output; } -pub trait DaftSquareSumAggable { - type Output; - fn square_sum(&self) -> Self::Output; - fn grouped_square_sum(&self, groups: &GroupIndices) -> Self::Output; -} - pub trait DaftApproxSketchAggable { type Output; fn approx_sketch(&self) -> Self::Output; diff --git a/src/daft-core/src/array/ops/square_sum.rs b/src/daft-core/src/array/ops/square_sum.rs deleted file mode 100644 index 56abc4550e..0000000000 --- a/src/daft-core/src/array/ops/square_sum.rs +++ /dev/null @@ -1,37 +0,0 @@ -use arrow2::array::PrimitiveArray; -use common_error::DaftResult; - -use crate::array::{ - ops::{DaftSquareSumAggable, GroupIndices}, - prelude::Float64Array, -}; - -impl DaftSquareSumAggable for Float64Array { - type Output = DaftResult; - - fn square_sum(&self) -> Self::Output { - let sum_square = self - .into_iter() - .flatten() - .copied() - .fold(0., |acc, value| value.mul_add(value, acc)); - let data = PrimitiveArray::from([Some(sum_square)]).boxed(); - let field = self.field.clone(); - Self::new(field, data) - } - - fn grouped_square_sum(&self, groups: &GroupIndices) -> Self::Output { - let grouped_square_sum_iter = groups - .iter() - .map(|group| { - group.iter().copied().fold(0., |acc, index| { - self.get(index as _) - .map_or(acc, |value| value.mul_add(value, acc)) - }) - }) - .map(Some); - let data = PrimitiveArray::from_trusted_len_iter(grouped_square_sum_iter).boxed(); - let field = self.field.clone(); - Self::new(field, data) - } -} diff --git a/src/daft-core/src/series/ops/agg.rs b/src/daft-core/src/series/ops/agg.rs index a4715afcb4..b3bfee765c 100644 --- a/src/daft-core/src/series/ops/agg.rs +++ b/src/daft-core/src/series/ops/agg.rs @@ -5,8 +5,8 @@ use logical::Decimal128Array; use crate::{ array::{ ops::{ - DaftApproxSketchAggable, DaftHllMergeAggable, DaftMeanAggable, DaftSquareSumAggable, - DaftStddevAggable, DaftSumAggable, GroupIndices, + DaftApproxSketchAggable, DaftHllMergeAggable, DaftMeanAggable, DaftStddevAggable, + DaftSumAggable, GroupIndices, }, ListArray, }, @@ -95,19 +95,6 @@ impl Series { } } - pub fn square_sum(&self, groups: Option<&GroupIndices>) -> DaftResult { - self.data_type().assert_is_numeric()?; - let casted = self.cast(&DataType::Float64)?; - let casted = casted.f64()?; - let series = groups - .map_or_else( - || casted.square_sum(), - |groups| casted.grouped_square_sum(groups), - )? - .into_series(); - Ok(series) - } - pub fn approx_sketch(&self, groups: Option<&GroupIndices>) -> DaftResult { // Upcast all numeric types to float64 and compute approx_sketch. match self.data_type() { diff --git a/src/daft-dsl/src/expr/mod.rs b/src/daft-dsl/src/expr/mod.rs index 1c7a73c984..3badc3c083 100644 --- a/src/daft-dsl/src/expr/mod.rs +++ b/src/daft-dsl/src/expr/mod.rs @@ -109,9 +109,6 @@ pub enum AggExpr { #[display("sum({_0})")] Sum(ExprRef), - #[display("square_sum({_0})")] - SquareSum(ExprRef), - #[display("approx_percentile({}, percentiles={:?}, force_list_output={})", _0.child, _0.percentiles, _0.force_list_output)] ApproxPercentile(ApproxPercentileParams), @@ -171,7 +168,6 @@ impl AggExpr { match self { Self::Count(expr, ..) | Self::Sum(expr) - | Self::SquareSum(expr) | Self::ApproxPercentile(ApproxPercentileParams { child: expr, .. }) | Self::ApproxCountDistinct(expr) | Self::ApproxSketch(expr, _) @@ -197,10 +193,6 @@ impl AggExpr { let child_id = expr.semantic_id(schema); FieldID::new(format!("{child_id}.local_sum()")) } - Self::SquareSum(expr) => { - let child_id = expr.semantic_id(schema); - FieldID::new(format!("{child_id}.local_square_sum()")) - } Self::ApproxPercentile(ApproxPercentileParams { child: expr, percentiles, @@ -266,7 +258,6 @@ impl AggExpr { match self { Self::Count(expr, ..) | Self::Sum(expr) - | Self::SquareSum(expr) | Self::ApproxPercentile(ApproxPercentileParams { child: expr, .. }) | Self::ApproxCountDistinct(expr) | Self::ApproxSketch(expr, _) @@ -292,7 +283,6 @@ impl AggExpr { match self { Self::Count(_, count_mode) => Self::Count(first_child(), *count_mode), Self::Sum(_) => Self::Sum(first_child()), - Self::SquareSum(_) => Self::SquareSum(first_child()), Self::Mean(_) => Self::Mean(first_child()), Self::Stddev(_) => Self::Stddev(first_child()), Self::Min(_) => Self::Min(first_child()), @@ -332,10 +322,7 @@ impl AggExpr { try_sum_supertype(&field.dtype)?, )) } - Self::SquareSum(expr) => { - let field = expr.to_field(schema)?; - Ok(Field::new(field.name.as_str(), DataType::Float64)) - } + Self::ApproxPercentile(ApproxPercentileParams { child: expr, percentiles, diff --git a/src/daft-dsl/src/resolve_expr/mod.rs b/src/daft-dsl/src/resolve_expr/mod.rs index 0b3c59f056..5888774fe4 100644 --- a/src/daft-dsl/src/resolve_expr/mod.rs +++ b/src/daft-dsl/src/resolve_expr/mod.rs @@ -218,7 +218,6 @@ fn extract_agg_expr(expr: &Expr) -> DaftResult { AggExpr::Count(Expr::Alias(e, name.clone()).into(), count_mode) } AggExpr::Sum(e) => AggExpr::Sum(Expr::Alias(e, name.clone()).into()), - AggExpr::SquareSum(e) => AggExpr::SquareSum(e.alias(name.clone())), AggExpr::ApproxPercentile(ApproxPercentileParams { child: e, percentiles, diff --git a/src/daft-plan/src/logical_ops/project.rs b/src/daft-plan/src/logical_ops/project.rs index 3b0ccb49c5..78de22bea6 100644 --- a/src/daft-plan/src/logical_ops/project.rs +++ b/src/daft-plan/src/logical_ops/project.rs @@ -375,10 +375,6 @@ fn replace_column_with_semantic_id_aggexpr( replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema) .map_yes_no(AggExpr::Sum, |_| e) } - AggExpr::SquareSum(ref child) => { - replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema) - .map_yes_no(AggExpr::SquareSum, |_| e) - } AggExpr::ApproxPercentile(ApproxPercentileParams { ref child, ref percentiles, diff --git a/src/daft-plan/src/physical_planner/translate.rs b/src/daft-plan/src/physical_planner/translate.rs index d689d654be..6cb8ca6c6d 100644 --- a/src/daft-plan/src/physical_planner/translate.rs +++ b/src/daft-plan/src/physical_planner/translate.rs @@ -824,9 +824,6 @@ pub fn populate_aggregation_stages( )); final_exprs.push(col(sum_of_sum_id.clone()).alias(output_name)); } - AggExpr::SquareSum(..) => { - unimplemented!("User-facing square_sum aggregation is not implemented") - } AggExpr::Mean(e) => { let sum_id = AggExpr::Sum(e.clone()).semantic_id(schema).id; let count_id = AggExpr::Count(e.clone(), CountMode::Valid) @@ -867,7 +864,7 @@ pub fn populate_aggregation_stages( &mut first_stage_aggs, ); let sq_sum_id = add_to_stage( - AggExpr::SquareSum, + |sub_expr| AggExpr::Sum(sub_expr.clone().mul(sub_expr)), sub_expr.clone(), schema, &mut first_stage_aggs, diff --git a/src/daft-sql/src/modules/aggs.rs b/src/daft-sql/src/modules/aggs.rs index d7a6098e13..7e8ceb5fcb 100644 --- a/src/daft-sql/src/modules/aggs.rs +++ b/src/daft-sql/src/modules/aggs.rs @@ -101,7 +101,6 @@ pub fn to_expr(expr: &AggExpr, args: &[ExprRef]) -> SQLPlannerResult { ensure!(args.len() == 1, "sum takes exactly one argument"); Ok(args[0].clone().sum()) } - AggExpr::SquareSum(_) => unsupported_sql_err!("square_sum"), AggExpr::ApproxCountDistinct(_) => unsupported_sql_err!("approx_percentile"), AggExpr::ApproxPercentile(_) => unsupported_sql_err!("approx_percentile"), AggExpr::ApproxSketch(_, _) => unsupported_sql_err!("approx_sketch"), diff --git a/src/daft-table/src/lib.rs b/src/daft-table/src/lib.rs index 10972755a8..cf96344a53 100644 --- a/src/daft-table/src/lib.rs +++ b/src/daft-table/src/lib.rs @@ -435,7 +435,6 @@ impl Table { match agg_expr { &AggExpr::Count(ref expr, mode) => self.eval_expression(expr)?.count(groups, mode), AggExpr::Sum(expr) => self.eval_expression(expr)?.sum(groups), - AggExpr::SquareSum(expr) => self.eval_expression(expr)?.square_sum(groups), &AggExpr::ApproxPercentile(ApproxPercentileParams { child: ref expr, ref percentiles, From c6eba4e4d49776184468fc9ffba0624b6c27c6c9 Mon Sep 17 00:00:00 2001 From: Raunak Bhagat Date: Tue, 8 Oct 2024 12:22:49 -0700 Subject: [PATCH 20/28] Add debug_assertions to length checking during stats calculations --- src/daft-core/src/utils/stats.rs | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/daft-core/src/utils/stats.rs b/src/daft-core/src/utils/stats.rs index bf429a4d14..6be03c7282 100644 --- a/src/daft-core/src/utils/stats.rs +++ b/src/daft-core/src/utils/stats.rs @@ -34,8 +34,11 @@ pub fn grouped_stats<'a>( ) -> DaftResult> { let grouped_sum = array.grouped_sum(groups)?; let grouped_count = array.grouped_count(groups, CountMode::Valid)?; - assert_eq!(grouped_sum.len(), grouped_count.len()); - assert_eq!(grouped_sum.len(), groups.len()); + #[cfg(debug_assertions)] + { + assert_eq!(grouped_sum.len(), grouped_count.len()); + assert_eq!(grouped_sum.len(), groups.len()); + }; Ok(GroupedStats { grouped_sum, grouped_count, From 4581104c3927df1384addd7d322ff011ab22ee9d Mon Sep 17 00:00:00 2001 From: Raunak Bhagat Date: Tue, 8 Oct 2024 12:25:47 -0700 Subject: [PATCH 21/28] Remove dead function and remove re-calculation of mean --- src/daft-core/src/array/ops/mean.rs | 3 +-- src/daft-dsl/src/expr/mod.rs | 15 --------------- 2 files changed, 1 insertion(+), 17 deletions(-) diff --git a/src/daft-core/src/array/ops/mean.rs b/src/daft-core/src/array/ops/mean.rs index 1bc659279a..09307d3dbc 100644 --- a/src/daft-core/src/array/ops/mean.rs +++ b/src/daft-core/src/array/ops/mean.rs @@ -17,8 +17,7 @@ impl DaftMeanAggable for DataArray { fn mean(&self) -> Self::Output { let stats = stats::calculate_stats(self)?; - let mean = stats::calculate_mean(stats.sum, stats.count); - let data = PrimitiveArray::from([mean]).boxed(); + let data = PrimitiveArray::from([stats.mean]).boxed(); let field = Arc::new(Field::new(self.field.name.clone(), DataType::Float64)); Self::new(field, data) } diff --git a/src/daft-dsl/src/expr/mod.rs b/src/daft-dsl/src/expr/mod.rs index 3badc3c083..949eccea62 100644 --- a/src/daft-dsl/src/expr/mod.rs +++ b/src/daft-dsl/src/expr/mod.rs @@ -411,21 +411,6 @@ impl AggExpr { Self::MapGroups { func, inputs } => func.to_field(inputs.as_slice(), schema, func), } } - - pub fn from_name_and_child_expr(name: &str, child: ExprRef) -> DaftResult { - match name { - "count" => Ok(Self::Count(child, CountMode::Valid)), - "sum" => Ok(Self::Sum(child)), - "mean" => Ok(Self::Mean(child)), - "min" => Ok(Self::Min(child)), - "max" => Ok(Self::Max(child)), - "list" => Ok(Self::List(child)), - _ => Err(DaftError::ValueError(format!( - "{} not a valid aggregation name", - name - ))), - } - } } impl From<&AggExpr> for ExprRef { From dd941b07394f51d21a334b82207e4b88771a00ab Mon Sep 17 00:00:00 2001 From: Raunak Bhagat Date: Tue, 8 Oct 2024 12:31:51 -0700 Subject: [PATCH 22/28] Change type of count to f64 to avoid casts in loop; remove panic assertion on re-insertion of id - it is possible to have an existing key already in the map; thus shouldn't panic - keeping the count as a u64 would require casting to f64 in the loop, which leads to poor performance - instead store it as an f64 eagerly --- src/daft-core/src/utils/stats.rs | 8 ++++---- src/daft-plan/src/physical_planner/translate.rs | 7 +------ 2 files changed, 5 insertions(+), 10 deletions(-) diff --git a/src/daft-core/src/utils/stats.rs b/src/daft-core/src/utils/stats.rs index 6be03c7282..f5d1737a8e 100644 --- a/src/daft-core/src/utils/stats.rs +++ b/src/daft-core/src/utils/stats.rs @@ -11,7 +11,7 @@ use crate::{ #[derive(Clone, Copy, Default, Debug)] pub struct Stats { pub sum: f64, - pub count: u64, + pub count: f64, pub mean: Option, } @@ -22,7 +22,7 @@ pub fn calculate_stats(array: &Float64Array) -> DaftResult { .zip(count) .map_or_else(Default::default, |(sum, count)| Stats { sum, - count, + count: count as _, mean: calculate_mean(sum, count), }); Ok(stats) @@ -63,7 +63,7 @@ impl<'a, I: Iterator> Iterator for GroupedStats< .zip(count) .map_or_else(Default::default, |(sum, count)| Stats { sum, - count, + count: count as _, mean: calculate_mean(sum, count), }); Some((stats, group)) @@ -80,6 +80,6 @@ pub fn calculate_mean(sum: f64, count: u64) -> Option { pub fn calculate_stddev(stats: Stats, values: impl Iterator) -> Option { stats.mean.map(|mean| { let sum_of_squares = values.map(|value| (value - mean).powi(2)).sum::(); - (sum_of_squares / stats.count as f64).sqrt() + (sum_of_squares / stats.count).sqrt() }) } diff --git a/src/daft-plan/src/physical_planner/translate.rs b/src/daft-plan/src/physical_planner/translate.rs index 6cb8ca6c6d..3514599bfe 100644 --- a/src/daft-plan/src/physical_planner/translate.rs +++ b/src/daft-plan/src/physical_planner/translate.rs @@ -786,12 +786,7 @@ pub fn populate_aggregation_stages( { let id = f(expr.clone()).semantic_id(schema).id; let agg_expr = f(expr.alias(id.clone())); - let prev_agg_expr = stage.insert(id.clone(), agg_expr); - assert!( - prev_agg_expr.is_none(), - "{:?} already exists in this stage but it should not", - &id - ); + stage.insert(id.clone(), agg_expr); id } From 823e3aff057d56c51f6d045d573a56421eca2e7f Mon Sep 17 00:00:00 2001 From: Raunak Bhagat Date: Tue, 8 Oct 2024 12:34:52 -0700 Subject: [PATCH 23/28] Change name of data-type function --- src/daft-core/src/datatypes/agg_ops.rs | 2 +- src/daft-core/src/datatypes/mod.rs | 2 +- src/daft-dsl/src/expr/mod.rs | 4 ++-- src/daft-functions/src/list/mean.rs | 4 ++-- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/daft-core/src/datatypes/agg_ops.rs b/src/daft-core/src/datatypes/agg_ops.rs index 53f0f19536..c1f04fecbe 100644 --- a/src/daft-core/src/datatypes/agg_ops.rs +++ b/src/daft-core/src/datatypes/agg_ops.rs @@ -23,7 +23,7 @@ pub fn try_sum_supertype(dtype: &DataType) -> DaftResult { } /// Get the data type that the mean of a column of the given data type should be casted to. -pub fn try_numeric_aggregation_supertype(dtype: &DataType) -> DaftResult { +pub fn try_mean_stddev_aggregation_supertype(dtype: &DataType) -> DaftResult { if dtype.is_numeric() { Ok(DataType::Float64) } else { diff --git a/src/daft-core/src/datatypes/mod.rs b/src/daft-core/src/datatypes/mod.rs index 36a010a9bd..01a6b6ca6e 100644 --- a/src/daft-core/src/datatypes/mod.rs +++ b/src/daft-core/src/datatypes/mod.rs @@ -6,7 +6,7 @@ pub use infer_datatype::InferDataType; pub mod prelude; use std::ops::{Add, Div, Mul, Rem, Sub}; -pub use agg_ops::{try_numeric_aggregation_supertype, try_sum_supertype}; +pub use agg_ops::{try_mean_stddev_aggregation_supertype, try_sum_supertype}; use arrow2::{ compute::comparison::Simd8, types::{simd::Simd, NativeType}, diff --git a/src/daft-dsl/src/expr/mod.rs b/src/daft-dsl/src/expr/mod.rs index 949eccea62..873f9013bd 100644 --- a/src/daft-dsl/src/expr/mod.rs +++ b/src/daft-dsl/src/expr/mod.rs @@ -10,7 +10,7 @@ use common_error::{DaftError, DaftResult}; use common_hashable_float_wrapper::FloatWrapper; use common_treenode::TreeNode; use daft_core::{ - datatypes::{try_numeric_aggregation_supertype, try_sum_supertype, InferDataType}, + datatypes::{try_mean_stddev_aggregation_supertype, try_sum_supertype, InferDataType}, prelude::*, utils::supertype::try_get_supertype, }; @@ -387,7 +387,7 @@ impl AggExpr { let field = expr.to_field(schema)?; Ok(Field::new( field.name.as_str(), - try_numeric_aggregation_supertype(&field.dtype)?, + try_mean_stddev_aggregation_supertype(&field.dtype)?, )) } Self::Min(expr) | Self::Max(expr) | Self::AnyValue(expr, _) => { diff --git a/src/daft-functions/src/list/mean.rs b/src/daft-functions/src/list/mean.rs index e42a9dd750..b01d3c1fa1 100644 --- a/src/daft-functions/src/list/mean.rs +++ b/src/daft-functions/src/list/mean.rs @@ -1,6 +1,6 @@ use common_error::{DaftError, DaftResult}; use daft_core::{ - datatypes::try_numeric_aggregation_supertype, + datatypes::try_mean_stddev_aggregation_supertype, prelude::{Field, Schema}, series::Series, }; @@ -29,7 +29,7 @@ impl ScalarUDF for ListMean { let inner_field = input.to_field(schema)?.to_exploded_field()?; Ok(Field::new( inner_field.name.as_str(), - try_numeric_aggregation_supertype(&inner_field.dtype)?, + try_mean_stddev_aggregation_supertype(&inner_field.dtype)?, )) } _ => Err(DaftError::SchemaMismatch(format!( From 53a0566a3421a5cd7e1403028c836412f96e1d03 Mon Sep 17 00:00:00 2001 From: Raunak Bhagat Date: Tue, 8 Oct 2024 13:33:51 -0700 Subject: [PATCH 24/28] Add comment to `populate_aggregation_stages`; explains what each agg-stage is doing --- src/daft-plan/src/physical_planner/translate.rs | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/daft-plan/src/physical_planner/translate.rs b/src/daft-plan/src/physical_planner/translate.rs index 3514599bfe..c7a364c770 100644 --- a/src/daft-plan/src/physical_planner/translate.rs +++ b/src/daft-plan/src/physical_planner/translate.rs @@ -851,6 +851,14 @@ pub fn populate_aggregation_stages( ); } AggExpr::Stddev(sub_expr) => { + // The stddev calculation we're performing here is: + // stddev(X) = sqrt(E(X^2) - E(X)^2) + // where X is the sub_expr. + // + // First stage, we compute `sum(X^2)`, `sum(X)` and `count(X)`. + // Second stage, we `global_sqsum := sum(sum(X^2))`, `global_sum := sum(sum(X))` and `global_count := sum(count(X))` in order to get the global versions of the first stage. + // In the final projection, we then compute `sqrt((global_sqsum / global_count) - (global_sum / global_count) ^ 2)`. + // first stage aggregation let sum_id = add_to_stage( AggExpr::Sum, From e4222f51ab2b1f5bbbbad91a75b6b3a7114293da Mon Sep 17 00:00:00 2001 From: Raunak Bhagat Date: Tue, 8 Oct 2024 13:51:01 -0700 Subject: [PATCH 25/28] Add docs to dataframe stddev API --- daft/dataframe/dataframe.py | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/daft/dataframe/dataframe.py b/daft/dataframe/dataframe.py index 114c4a598f..2408890d7b 100644 --- a/daft/dataframe/dataframe.py +++ b/daft/dataframe/dataframe.py @@ -2122,6 +2122,22 @@ def mean(self, *cols: ColumnInputType) -> "DataFrame": def stddev(self, *cols: ColumnInputType) -> "DataFrame": """Performs a global standard deviation on the DataFrame + Example: + >>> import daft + >>> df = daft.from_pydict({"col_a":[0,1,2]}) + >>> df = df.stddev("col_a") + >>> df.show() + ╭───────────────────╮ + │ col_a │ + │ --- │ + │ Float64 │ + ╞═══════════════════╡ + │ 0.816496580927726 │ + ╰───────────────────╯ + + (Showing first 1 of 1 rows) + + Args: *cols (Union[str, Expression]): columns to stddev Returns: @@ -2870,6 +2886,23 @@ def mean(self, *cols: ColumnInputType) -> "DataFrame": def stddev(self, *cols: ColumnInputType) -> "DataFrame": """Performs grouped standard deviation on this GroupedDataFrame. + Example: + >>> import daft + >>> df = daft.from_pydict({"keys": ["a", "a", "a", "b"], "col_a": [0,1,2,100]}) + >>> df = df.groupby("keys").stddev() + >>> df.show() + ╭──────┬───────────────────╮ + │ keys ┆ col_a │ + │ --- ┆ --- │ + │ Utf8 ┆ Float64 │ + ╞══════╪═══════════════════╡ + │ a ┆ 0.816496580927726 │ + ├╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤ + │ b ┆ 0 │ + ╰──────┴───────────────────╯ + + (Showing first 2 of 2 rows) + Args: *cols (Union[str, Expression]): columns to stddev From 0c976a42ab91b0519412d60ec14cc7f6b17374cf Mon Sep 17 00:00:00 2001 From: Raunak Bhagat Date: Tue, 8 Oct 2024 13:56:06 -0700 Subject: [PATCH 26/28] Change `assert_eq` to `debug_assert_eq` --- src/daft-core/src/utils/stats.rs | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/daft-core/src/utils/stats.rs b/src/daft-core/src/utils/stats.rs index f5d1737a8e..de43b186ea 100644 --- a/src/daft-core/src/utils/stats.rs +++ b/src/daft-core/src/utils/stats.rs @@ -34,11 +34,8 @@ pub fn grouped_stats<'a>( ) -> DaftResult> { let grouped_sum = array.grouped_sum(groups)?; let grouped_count = array.grouped_count(groups, CountMode::Valid)?; - #[cfg(debug_assertions)] - { - assert_eq!(grouped_sum.len(), grouped_count.len()); - assert_eq!(grouped_sum.len(), groups.len()); - }; + debug_assert_eq!(grouped_sum.len(), grouped_count.len()); + debug_assert_eq!(grouped_sum.len(), groups.len()); Ok(GroupedStats { grouped_sum, grouped_count, From 65e443a9031360aca772fbffde296f3294aa1518 Mon Sep 17 00:00:00 2001 From: Raunak Bhagat Date: Tue, 8 Oct 2024 14:01:14 -0700 Subject: [PATCH 27/28] Update grouped-mean impl to use stats --- src/daft-core/src/array/ops/mean.rs | 21 ++++----------------- 1 file changed, 4 insertions(+), 17 deletions(-) diff --git a/src/daft-core/src/array/ops/mean.rs b/src/daft-core/src/array/ops/mean.rs index 09307d3dbc..d5764c4954 100644 --- a/src/daft-core/src/array/ops/mean.rs +++ b/src/daft-core/src/array/ops/mean.rs @@ -4,10 +4,7 @@ use arrow2::array::PrimitiveArray; use common_error::DaftResult; use crate::{ - array::ops::{ - as_arrow::AsArrow, DaftCountAggable, DaftMeanAggable, DaftSumAggable, GroupIndices, - }, - count_mode::CountMode, + array::ops::{DaftMeanAggable, GroupIndices}, datatypes::*, utils::stats, }; @@ -23,18 +20,8 @@ impl DaftMeanAggable for DataArray { } fn grouped_mean(&self, groups: &GroupIndices) -> Self::Output { - let sum_values = self.grouped_sum(groups)?; - let count_values = self.grouped_count(groups, CountMode::Valid)?; - assert_eq!(sum_values.len(), count_values.len()); - let mean_per_group = sum_values - .as_arrow() - .values_iter() - .zip(count_values.as_arrow().values_iter()) - .map(|(s, c)| match (s, c) { - (_, 0) => None, - (s, c) => Some(s / (*c as f64)), - }); - let mean_array = Box::new(PrimitiveArray::from_trusted_len_iter(mean_per_group)); - Ok(Self::from((self.field.name.as_ref(), mean_array))) + let grouped_means = stats::grouped_stats(self, groups)?.map(|(stats, _)| stats.mean); + let data = Box::new(PrimitiveArray::from_iter(grouped_means)); + Ok(Self::from((self.field.name.as_ref(), data))) } } From 1874f43825fb0187f78099b7f4ce3b7415070506 Mon Sep 17 00:00:00 2001 From: Raunak Bhagat Date: Tue, 8 Oct 2024 14:16:39 -0700 Subject: [PATCH 28/28] Add to docs --- docs/source/api_docs/dataframe.rst | 1 + docs/source/api_docs/expressions.rst | 1 + 2 files changed, 2 insertions(+) diff --git a/docs/source/api_docs/dataframe.rst b/docs/source/api_docs/dataframe.rst index f93f052742..14a4e9fa20 100644 --- a/docs/source/api_docs/dataframe.rst +++ b/docs/source/api_docs/dataframe.rst @@ -104,6 +104,7 @@ Aggregations DataFrame.groupby DataFrame.sum DataFrame.mean + DataFrame.stddev DataFrame.count DataFrame.min DataFrame.max diff --git a/docs/source/api_docs/expressions.rst b/docs/source/api_docs/expressions.rst index ec86e0bb5e..a53ef825fd 100644 --- a/docs/source/api_docs/expressions.rst +++ b/docs/source/api_docs/expressions.rst @@ -113,6 +113,7 @@ The following can be used with DataFrame.agg or GroupedDataFrame.agg Expression.count Expression.sum Expression.mean + Expression.stddev Expression.min Expression.max Expression.any_value