Skip to content

Commit

Permalink
Add regr_slope() aggregate function (#7135)
Browse files Browse the repository at this point in the history
  • Loading branch information
2010YOUY01 committed Aug 1, 2023
1 parent 7af25de commit a9561a0
Show file tree
Hide file tree
Showing 13 changed files with 550 additions and 29 deletions.
158 changes: 158 additions & 0 deletions datafusion/core/tests/sqllogictests/test_files/aggregate.slt
Original file line number Diff line number Diff line change
Expand Up @@ -2290,3 +2290,161 @@ true
false
true
NULL



#
# regr_slope() tests
#

# invalid input
statement error
select regr_slope();

statement error
select regr_slope(*);

statement error
select regr_slope(*) from aggregate_test_100;

statement error
select regr_slope(1);

statement error
select regr_slope(1,2,3);

statement error
select regr_slope(1, 'foo');

statement error
select regr_slope('foo', 1);

statement error
select regr_slope('foo', 'bar');



# regr_slope() NULL result
query R
select regr_slope(1,1);
----
NULL

query R
select regr_slope(1, NULL);
----
NULL

query R
select regr_slope(NULL, 1);
----
NULL

query R
select regr_slope(NULL, NULL);
----
NULL

query R
select regr_slope(column2, column1) from (values (1,2), (1,4), (1,6));
----
NULL



# regr_slope() basic tests
query R
select regr_slope(column2, column1) from (values (1,2), (2,4), (3,6));
----
2

query R
select regr_slope(c12, c11) from aggregate_test_100;
----
0.051534002628



# regr_slope() ignores NULLs
query R
select regr_slope(column2, column1) from (values (1,NULL), (2,4), (3,6));
----
2

query R
select regr_slope(column2, column1) from (values (1,NULL), (NULL,4), (3,6));
----
NULL

query R
select regr_slope(column2, column1) from (values (1,NULL), (NULL,4), (NULL,NULL));
----
NULL

query TR rowsort
select column3, regr_slope(column2, column1)
from (values (1,2,'a'), (2,4,'a'), (1,3,'b'), (3,9,'b'), (1,10,'c'), (NULL,100,'c'))
group by column3;
----
a 2
b 3
c NULL



# regr_slope() testing merge_batch() from RegrSlopeAccumulator's internal implementation
statement ok
set datafusion.execution.batch_size = 1;

query R
select regr_slope(c12, c11) from aggregate_test_100;
----
0.051534002628

statement ok
set datafusion.execution.batch_size = 2;

query R
select regr_slope(c12, c11) from aggregate_test_100;
----
0.051534002628

statement ok
set datafusion.execution.batch_size = 3;

query R
select regr_slope(c12, c11) from aggregate_test_100;
----
0.051534002628

statement ok
set datafusion.execution.batch_size = 8192;



# regr_slope() testing retract_batch() from RegrSlopeAccumulator's internal implementation
query R
select regr_slope(column2, column1)
over (order by column1 rows between 2 preceding and current row)
from (values (1,2), (2,4), (3,6), (4,12), (5,15), (6, 18));
----
NULL
2
2
4
4.5
3

query R
select regr_slope(column2, column1)
over (order by column1 rows between 2 preceding and current row)
from (values (1,2), (2,4), (3,6), (3, NULL), (4, NULL), (5,15), (6,18), (7, 21));
----
NULL
2
2
2
NULL
NULL
3
3
13 changes: 9 additions & 4 deletions datafusion/expr/src/aggregate_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ pub enum AggregateFunction {
CovariancePop,
/// Correlation
Correlation,
/// Slope from linear regression
RegrSlope,
/// Approximate continuous percentile function
ApproxPercentileCont,
/// Approximate continuous percentile function with weight
Expand Down Expand Up @@ -102,6 +104,7 @@ impl AggregateFunction {
Covariance => "COVARIANCE",
CovariancePop => "COVARIANCE_POP",
Correlation => "CORRELATION",
RegrSlope => "REGR_SLOPE",
ApproxPercentileCont => "APPROX_PERCENTILE_CONT",
ApproxPercentileContWithWeight => "APPROX_PERCENTILE_CONT_WITH_WEIGHT",
ApproxMedian => "APPROX_MEDIAN",
Expand Down Expand Up @@ -152,6 +155,7 @@ impl FromStr for AggregateFunction {
"var" => AggregateFunction::Variance,
"var_pop" => AggregateFunction::VariancePop,
"var_samp" => AggregateFunction::Variance,
"regr_slope" => AggregateFunction::RegrSlope,
// approximate
"approx_distinct" => AggregateFunction::ApproxDistinct,
"approx_median" => AggregateFunction::ApproxMedian,
Expand Down Expand Up @@ -228,6 +232,7 @@ impl AggregateFunction {
}
AggregateFunction::Stddev => stddev_return_type(&coerced_data_types[0]),
AggregateFunction::StddevPop => stddev_return_type(&coerced_data_types[0]),
AggregateFunction::RegrSlope => Ok(DataType::Float64),
AggregateFunction::Avg => avg_return_type(&coerced_data_types[0]),
AggregateFunction::ArrayAgg => Ok(DataType::List(Arc::new(Field::new(
"item",
Expand Down Expand Up @@ -311,10 +316,10 @@ impl AggregateFunction {
| AggregateFunction::LastValue => {
Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable)
}
AggregateFunction::Covariance | AggregateFunction::CovariancePop => {
Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable)
}
AggregateFunction::Correlation => {
AggregateFunction::Covariance
| AggregateFunction::CovariancePop
| AggregateFunction::Correlation
| AggregateFunction::RegrSlope => {
Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable)
}
AggregateFunction::ApproxPercentileCont => {
Expand Down
35 changes: 10 additions & 25 deletions datafusion/expr/src/type_coercion/aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ pub fn coerce_types(
}
Ok(input_types.to_vec())
}
AggregateFunction::Variance => {
AggregateFunction::Variance | AggregateFunction::VariancePop => {
if !is_variance_support_arg_type(&input_types[0]) {
return Err(DataFusionError::Plan(format!(
"The function {:?} does not support inputs of type {:?}.",
Expand All @@ -157,16 +157,7 @@ pub fn coerce_types(
}
Ok(input_types.to_vec())
}
AggregateFunction::VariancePop => {
if !is_variance_support_arg_type(&input_types[0]) {
return Err(DataFusionError::Plan(format!(
"The function {:?} does not support inputs of type {:?}.",
agg_fun, input_types[0]
)));
}
Ok(input_types.to_vec())
}
AggregateFunction::Covariance => {
AggregateFunction::Covariance | AggregateFunction::CovariancePop => {
if !is_covariance_support_arg_type(&input_types[0]) {
return Err(DataFusionError::Plan(format!(
"The function {:?} does not support inputs of type {:?}.",
Expand All @@ -175,16 +166,7 @@ pub fn coerce_types(
}
Ok(input_types.to_vec())
}
AggregateFunction::CovariancePop => {
if !is_covariance_support_arg_type(&input_types[0]) {
return Err(DataFusionError::Plan(format!(
"The function {:?} does not support inputs of type {:?}.",
agg_fun, input_types[0]
)));
}
Ok(input_types.to_vec())
}
AggregateFunction::Stddev => {
AggregateFunction::Stddev | AggregateFunction::StddevPop => {
if !is_stddev_support_arg_type(&input_types[0]) {
return Err(DataFusionError::Plan(format!(
"The function {:?} does not support inputs of type {:?}.",
Expand All @@ -193,17 +175,20 @@ pub fn coerce_types(
}
Ok(input_types.to_vec())
}
AggregateFunction::StddevPop => {
if !is_stddev_support_arg_type(&input_types[0]) {
AggregateFunction::Correlation => {
if !is_correlation_support_arg_type(&input_types[0]) {
return Err(DataFusionError::Plan(format!(
"The function {:?} does not support inputs of type {:?}.",
agg_fun, input_types[0]
)));
}
Ok(input_types.to_vec())
}
AggregateFunction::Correlation => {
if !is_correlation_support_arg_type(&input_types[0]) {
AggregateFunction::RegrSlope => {
let valid_types = [NUMERICS.to_vec(), vec![DataType::Null]].concat();
let input_types_valid = // the number of inputs has already been checked
valid_types.contains(&input_types[0]) && valid_types.contains(&input_types[1]);
if !input_types_valid {
return Err(DataFusionError::Plan(format!(
"The function {:?} does not support inputs of type {:?}.",
agg_fun, input_types[0]
Expand Down
11 changes: 11 additions & 0 deletions datafusion/physical-expr/src/aggregate/build_in.rs
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,17 @@ pub fn create_aggregate_expr(
"CORR(DISTINCT) aggregations are not available".to_string(),
));
}
(AggregateFunction::RegrSlope, false) => Arc::new(expressions::RegrSlope::new(
input_phy_exprs[0].clone(),
input_phy_exprs[1].clone(),
name,
rt_type,
)),
(AggregateFunction::RegrSlope, true) => {
return Err(DataFusionError::NotImplemented(
"REGR_SLOPE(DISTINCT) aggregations are not available".to_string(),
));
}
(AggregateFunction::ApproxPercentileCont, false) => {
if input_phy_exprs.len() == 2 {
Arc::new(expressions::ApproxPercentileCont::new(
Expand Down
1 change: 1 addition & 0 deletions datafusion/physical-expr/src/aggregate/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ pub mod build_in;
pub(crate) mod groups_accumulator;
mod hyperloglog;
pub mod moving_min_max;
pub(crate) mod regr_slope;
pub(crate) mod stats;
pub(crate) mod stddev;
pub(crate) mod sum;
Expand Down
Loading

0 comments on commit a9561a0

Please sign in to comment.