Skip to content

Commit

Permalink
Add functionality to differentiate between user input of 0.5 vs [0.5]
Browse files Browse the repository at this point in the history
  • Loading branch information
Jay Chia committed May 1, 2024
1 parent 6e3a17e commit 77e5fd3
Show file tree
Hide file tree
Showing 12 changed files with 137 additions and 67 deletions.
9 changes: 6 additions & 3 deletions src/daft-core/src/array/ops/approx_sketch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@ impl DaftApproxSketchAggable for &DataArray<Float64Type> {

fn approx_sketch(&self) -> Self::Output {
let primitive_arr = self.as_arrow();
let arrow_array = if primitive_arr.null_count() > 0 {
let arrow_array = if primitive_arr.is_empty() {
daft_sketch::into_arrow2(vec![])
} else if primitive_arr.null_count() > 0 {
let sketch = primitive_arr
.iter()
.fold(None, |acc, value| match (acc, value) {
Expand All @@ -29,7 +31,6 @@ impl DaftApproxSketchAggable for &DataArray<Float64Type> {
Some(acc)
}
});

daft_sketch::into_arrow2(vec![sketch])
} else {
let sketch = primitive_arr.values_iter().fold(
Expand All @@ -55,7 +56,9 @@ impl DaftApproxSketchAggable for &DataArray<Float64Type> {

fn grouped_approx_sketch(&self, groups: &GroupIndices) -> Self::Output {
let arrow_array = self.as_arrow();
let sketch_per_group = if arrow_array.null_count() > 0 {
let sketch_per_group = if arrow_array.is_empty() {
daft_sketch::into_arrow2(vec![])
} else if arrow_array.null_count() > 0 {
let sketches: Vec<Option<DDSketch>> = groups
.iter()
.map(|g| {
Expand Down
18 changes: 11 additions & 7 deletions src/daft-core/src/array/ops/sketch_percentile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,16 @@ use common_error::DaftResult;
use super::from_arrow::FromArrow;

impl StructArray {
pub fn sketch_percentile(&self, percentiles: &[f64]) -> DaftResult<Series> {
let output_len = percentiles.len();
let output_dtype = DataType::FixedSizeList(Box::new(DataType::Float64), output_len);
pub fn sketch_percentile(
&self,
percentiles: &[f64],
force_list_output: bool,
) -> DaftResult<Series> {
let output_dtype = DataType::FixedSizeList(Box::new(DataType::Float64), percentiles.len());
let output_field = Field::new(self.field.name.as_str(), output_dtype);

let mut flat_child = MutablePrimitiveArray::<f64>::with_capacity(output_len * self.len());
let mut flat_child =
MutablePrimitiveArray::<f64>::with_capacity(percentiles.len() * self.len());
daft_sketch::from_arrow2(self.to_arrow())?
.iter()
.for_each(|sketch| match sketch {
Expand All @@ -36,13 +40,13 @@ impl StructArray {
)?
.into_series();

if output_len == 1 {
Ok(flat_child)
} else {
if percentiles.len() > 1 || force_list_output {
Ok(
FixedSizeListArray::new(output_field, flat_child, self.validity().cloned())
.into_series(),
)
} else {
Ok(flat_child)
}
}
}
10 changes: 8 additions & 2 deletions src/daft-core/src/series/ops/sketch_percentile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,17 @@ use common_error::DaftError;
use common_error::DaftResult;

impl Series {
pub fn sketch_percentile(&self, percentiles: &[f64]) -> DaftResult<Series> {
pub fn sketch_percentile(
&self,
percentiles: &[f64],
force_list_output: bool,
) -> DaftResult<Series> {
use crate::datatypes::DataType::*;

match self.data_type() {
Struct(_) => Ok(self.struct_()?.sketch_percentile(percentiles)?),
Struct(_) => Ok(self
.struct_()?
.sketch_percentile(percentiles, force_list_output)?),
other => Err(DaftError::TypeError(format!(
"sketch_percentile is not implemented for type {}",
other
Expand Down
50 changes: 33 additions & 17 deletions src/daft-dsl/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ pub enum Expr {
/// Parameters carried by the `AggExpr::ApproxPercentile` aggregation.
pub struct ApproxPercentileParams {
/// Expression whose values the percentiles are computed over.
pub child: ExprRef,
/// Requested percentiles, each wrapped in `FloatWrapper`
/// (presumably so the `f64` values can take part in hashing/equality — confirm
/// against `hashable_float_wrapper`).
pub percentiles: Vec<daft_core::utils::hashable_float_wrapper::FloatWrapper<f64>>,
/// Selects the output type when exactly one percentile is requested:
/// `true` keeps a `FixedSizeList` of `Float64` (length `percentiles.len()`),
/// `false` yields a plain `Float64`. With more than one percentile the output
/// is a `FixedSizeList` regardless of this flag. The Python binding sets it to
/// `true` for list input and `false` for a single scalar input.
pub force_list_output: bool,
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
Expand Down Expand Up @@ -127,10 +128,11 @@ impl AggExpr {
ApproxPercentile(ApproxPercentileParams {
child: expr,
percentiles,
force_list_output,
}) => {
let child_id = expr.semantic_id(schema);
FieldID::new(format!(
"{child_id}.local_approx_percentiles(percentiles={:?})",
"{child_id}.local_approx_percentiles(percentiles={:?},force_list_output={force_list_output})",
percentiles,
))
}
Expand Down Expand Up @@ -207,12 +209,15 @@ impl AggExpr {
func: func.clone(),
inputs: children,
},
ApproxPercentile(ApproxPercentileParams { percentiles, .. }) => {
ApproxPercentile(ApproxPercentileParams {
child: children[0].clone(),
percentiles: percentiles.clone(),
})
}
ApproxPercentile(ApproxPercentileParams {
percentiles,
force_list_output,
..
}) => ApproxPercentile(ApproxPercentileParams {
child: children[0].clone(),
percentiles: percentiles.clone(),
force_list_output: *force_list_output,
}),
ApproxSketch(_) => ApproxSketch(children[0].clone()),
MergeSketch(_) => MergeSketch(children[0].clone()),
}
Expand Down Expand Up @@ -259,15 +264,16 @@ impl AggExpr {
ApproxPercentile(ApproxPercentileParams {
child: expr,
percentiles,
force_list_output,
}) => {
let field = expr.to_field(schema)?;
Ok(Field::new(
field.name.as_str(),
match &field.dtype {
dt if dt.is_numeric() => if percentiles.len() == 1 {
DataType::Float64
} else {
dt if dt.is_numeric() => if percentiles.len() > 1 || *force_list_output {
DataType::FixedSizeList(Box::new(DataType::Float64), percentiles.len())
} else {
DataType::Float64
},
other => {
return Err(DaftError::TypeError(format!(
Expand Down Expand Up @@ -378,22 +384,32 @@ impl Expr {
Expr::Agg(AggExpr::ApproxSketch(self)).into()
}

pub fn approx_percentiles(self: ExprRef, percentiles: &[f64]) -> ExprRef {
pub fn approx_percentiles(
self: ExprRef,
percentiles: &[f64],
force_list_output: bool,
) -> ExprRef {
Expr::Agg(AggExpr::ApproxPercentile(ApproxPercentileParams {
child: self,
percentiles: percentiles
.iter()
.map(|f| daft_core::utils::hashable_float_wrapper::FloatWrapper(*f))
.collect(),
force_list_output,
}))
.into()
}

pub fn sketch_percentile(self: ExprRef, percentiles: &[f64]) -> ExprRef {
pub fn sketch_percentile(
self: ExprRef,
percentiles: &[f64],
force_list_output: bool,
) -> ExprRef {
Expr::Function {
func: FunctionExpr::Sketch(SketchExpr::Percentile(HashableVecPercentiles(
percentiles.to_vec(),
))),
func: FunctionExpr::Sketch(SketchExpr::Percentile {
percentiles: HashableVecPercentiles(percentiles.to_vec()),
force_list_output,
}),
inputs: vec![self],
}
.into()
Expand Down Expand Up @@ -890,9 +906,9 @@ impl Display for AggExpr {
Count(expr, mode) => write!(f, "count({expr}, {mode})"),
Sum(expr) => write!(f, "sum({expr})"),
ApproxSketch(expr) => write!(f, "approx_sketch({expr})"),
ApproxPercentile(ApproxPercentileParams { child, percentiles }) => write!(
ApproxPercentile(ApproxPercentileParams { child, percentiles, force_list_output }) => write!(
f,
"approx_percentiles({child}, percentiles={percentiles:?})"
"approx_percentiles({child}, percentiles={percentiles:?}, force_list_output={force_list_output})"
),
MergeSketch(expr) => write!(f, "merge_sketch({expr})"),
Mean(expr) => write!(f, "mean({expr})"),
Expand Down
16 changes: 10 additions & 6 deletions src/daft-dsl/src/functions/sketch/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,24 +23,28 @@ impl Eq for HashableVecPercentiles {}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub enum SketchExpr {
Percentile(HashableVecPercentiles),
Percentile {
percentiles: HashableVecPercentiles,
force_list_output: bool,
},
}

impl SketchExpr {
#[inline]
pub fn get_evaluator(&self) -> &dyn FunctionEvaluator {
use SketchExpr::*;
match self {
Percentile(_) => &PercentileEvaluator {},
Percentile { .. } => &PercentileEvaluator {},
}
}
}

pub fn sketch_percentile(input: ExprRef, percentiles: &[f64]) -> ExprRef {
pub fn sketch_percentile(input: ExprRef, percentiles: &[f64], force_list_output: bool) -> ExprRef {
Expr::Function {
func: super::FunctionExpr::Sketch(SketchExpr::Percentile(HashableVecPercentiles(
percentiles.to_vec(),
))),
func: super::FunctionExpr::Sketch(SketchExpr::Percentile {
percentiles: HashableVecPercentiles(percentiles.to_vec()),
force_list_output,
}),
inputs: vec![input.clone()],
}
.into()
Expand Down
20 changes: 12 additions & 8 deletions src/daft-dsl/src/functions/sketch/percentile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,19 +25,22 @@ impl FunctionEvaluator for PercentileEvaluator {
match (input_field.dtype, expr) {
(
DataType::Struct(_),
FunctionExpr::Sketch(SketchExpr::Percentile(percentiles)),
FunctionExpr::Sketch(SketchExpr::Percentile {
percentiles,
force_list_output,
}),
) => Ok(Field::new(
input_field.name,
if percentiles.0.len() == 1 {
DataType::Float64
} else {
if percentiles.0.len() > 1 || *force_list_output {
DataType::FixedSizeList(
Box::new(DataType::Float64),
percentiles.0.len(),
)
} else {
DataType::Float64
},
)),
(input_field_dtype, FunctionExpr::Sketch(SketchExpr::Percentile(_))) => {
(input_field_dtype, FunctionExpr::Sketch(SketchExpr::Percentile { .. })) => {
Err(DaftError::TypeError(format!(
"Expected input to be a struct type, received: {}",
input_field_dtype
Expand All @@ -59,9 +62,10 @@ impl FunctionEvaluator for PercentileEvaluator {
fn evaluate(&self, inputs: &[Series], expr: &FunctionExpr) -> DaftResult<Series> {
match inputs {
[input] => match expr {
FunctionExpr::Sketch(SketchExpr::Percentile(percentiles)) => {
input.sketch_percentile(percentiles.0.as_slice())
}
FunctionExpr::Sketch(SketchExpr::Percentile {
percentiles,
force_list_output,
}) => input.sketch_percentile(percentiles.0.as_slice(), *force_list_output),
_ => unreachable!(
"PercentileEvaluator must evaluate a SketchExpr::Percentile expression"
),
Expand Down
8 changes: 4 additions & 4 deletions src/daft-dsl/src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -302,9 +302,9 @@ impl PyExpr {
}

pub fn approx_percentiles(&self, percentiles: ApproxPercentileInput) -> PyResult<Self> {
let percentiles = match percentiles {
ApproxPercentileInput::Single(p) => vec![p],
ApproxPercentileInput::Many(p) => p,
let (percentiles, list_output) = match percentiles {
ApproxPercentileInput::Single(p) => (vec![p], false),
ApproxPercentileInput::Many(p) => (p, true),
};

for &p in percentiles.iter() {
Expand All @@ -319,7 +319,7 @@ impl PyExpr {
Ok(self
.expr
.clone()
.approx_percentiles(percentiles.as_slice())
.approx_percentiles(percentiles.as_slice(), list_output)
.into())
}

Expand Down
2 changes: 2 additions & 0 deletions src/daft-plan/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,11 @@ fn extract_agg_expr(expr: &Expr) -> DaftResult<daft_dsl::AggExpr> {
ApproxPercentile(ApproxPercentileParams {
child: e,
percentiles,
force_list_output,
}) => ApproxPercentile(ApproxPercentileParams {
child: Alias(e, name.clone()).into(),
percentiles,
force_list_output,
}),
MergeSketch(e) => MergeSketch(Alias(e, name.clone()).into()),
Mean(e) => Mean(Alias(e, name.clone()).into()),
Expand Down
2 changes: 2 additions & 0 deletions src/daft-plan/src/logical_ops/project.rs
Original file line number Diff line number Diff line change
Expand Up @@ -368,12 +368,14 @@ fn replace_column_with_semantic_id_aggexpr(
AggExpr::ApproxPercentile(ApproxPercentileParams {
ref child,
ref percentiles,
ref force_list_output,
}) => replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema)
.map_yes_no(
|transformed_child| {
AggExpr::ApproxPercentile(ApproxPercentileParams {
child: transformed_child,
percentiles: percentiles.clone(),
force_list_output: *force_list_output,
})
},
|_| e.clone(),
Expand Down
6 changes: 5 additions & 1 deletion src/daft-plan/src/physical_planner/translate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,7 @@ pub(super) fn translate_single_logical_node(
ApproxPercentile(ApproxPercentileParams {
child: e,
percentiles,
force_list_output,
}) => {
let percentiles =
percentiles.iter().map(|p| p.0).collect::<Vec<f64>>();
Expand All @@ -334,7 +335,10 @@ pub(super) fn translate_single_logical_node(
));
final_exprs.push(
col(approx_id.clone())
.sketch_percentile(percentiles.as_slice())
.sketch_percentile(
percentiles.as_slice(),
*force_list_output,
)
.alias(output_name),
);
}
Expand Down
3 changes: 2 additions & 1 deletion src/daft-table/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -326,10 +326,11 @@ impl Table {
ApproxPercentile(ApproxPercentileParams {
child: expr,
percentiles,
force_list_output,
}) => {
let percentiles = percentiles.iter().map(|p| p.0).collect::<Vec<f64>>();
Series::approx_sketch(&self.eval_expression(expr)?, groups)?
.sketch_percentile(&percentiles)
.sketch_percentile(&percentiles, *force_list_output)
}
MergeSketch(expr) => Series::merge_sketch(&self.eval_expression(expr)?, groups),
Mean(expr) => Series::mean(&self.eval_expression(expr)?, groups),
Expand Down
Loading

0 comments on commit 77e5fd3

Please sign in to comment.